zouhsab committed on
Commit
9d5a733
1 Parent(s): 36b2a2d

Upload 76 files

This view is limited to 50 files because it contains too many changes. See raw diff
__init__.py ADDED
File without changes
alpaca.png ADDED
answer_alpaca-13b.jsonl ADDED
@@ -0,0 +1,80 @@
+ {"question_id": 1, "text": "Improving time management skills involves setting priorities, breaking tasks into smaller chunks, delegating tasks, avoiding multitasking, and taking regular breaks. Additionally, it is important to have a positive attitude and be realistic with goals. Making a to-do list, tracking time, and using technology to automate mundane tasks can also help improve time management skills.", "answer_id": "kEL9ifUHDeYuAXzevje2se", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 2, "text": "The most effective ways to deal with stress include identifying and avoiding the sources of stress, developing healthy coping mechanisms, and engaging in relaxation techniques such as meditation, deep breathing, and physical exercise. Additionally, it is important to take care of one's mental and physical health, get enough rest and nutrition, and practice self-care.", "answer_id": "VcF3NrWGXhhxLkDVurNrwq", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 3, "text": "The main differences between Python and JavaScript programming languages are in their syntax, structure, and the type of applications they are designed for. Python is a general-purpose language designed to be easy to read and write, while JavaScript is primarily used for web development. Python is a compiled language, while JavaScript is an interpreted language. Python is strongly typed, while JavaScript is not. Python is typically faster than JavaScript in terms of execution, while JavaScript is better suited to create interactive web pages.", "answer_id": "LpvtyQi9QdSgRrgGDxiGrT", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 4, "text": "Working from home can be a great way to increase productivity, but it does require some strategic planning and tactics. Here are some tips for increasing your productivity while working from home: \n1. Establish a routine: Having a regular schedule and sticking to it can help you stay focused and motivated.\n2. Create a distraction-free zone: Eliminate distractions by turning off notifications and setting a \u201cdo not disturb\u201d mode.\n3. Take breaks: Working for long stretches without taking breaks can lead to burnout and decreased productivity.\n4. Exercise: Exercise can help improve focus and productivity.\n5. Get organized: Having a clean and organized workspace can help you stay on task.\n6. Set goals: Set short-term and long-term goals to keep you motivated and on track.\n7. Connect with colleagues: Don\u2019t be afraid to reach out to colleagues for help or advice.\n8. Track your progress: Keep track of your progress and celebrate your successes.\n9. Unplug: Taking breaks and unplugging from work can help you recharge and come back to work more refreshed.", "answer_id": "7zQm8cSTJhPtPdZdxbcfrX", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 5, "text": "Sure. Quantum computing is a new technology that uses quantum mechanical phenomena to perform operations on data. Unlike traditional computers, which process information in binary form, quantum computers can process information in multiple states at the same time, allowing them to perform operations faster and with more accuracy. It also allows them to solve problems that classical computers struggle with. Quantum computing is still in its infancy, but it has the potential to revolutionize the computing industry and open up new possibilities in artificial intelligence, cryptography, and other fields.", "answer_id": "UrLEH82RHwqqLt2LyvYSKj", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 6, "text": "Plant-based protein sources include legumes, grains, nuts, and seeds, and are generally lower in fat and calories than animal-based proteins. Animal-based proteins come from sources such as meat, eggs, and dairy, and are generally higher in fat and calories than plant-based proteins. Plant-based proteins are also generally lower in cholesterol and saturated fat than animal-based proteins. Additionally, plant-based proteins require fewer resources for production, such as water and land, than animal-based proteins.", "answer_id": "fpRdMTdnfirosQixuf2Gez", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 7, "text": "Developing critical thinking skills involves actively questioning assumptions, taking multiple perspectives, analyzing information, and coming to logical conclusions. You can practice these skills by posing questions to yourself and working through problems, as well as engaging in thoughtful discussion with others.", "answer_id": "PvGmsCJSNFcvQKmPTnnd7s", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 8, "text": "The major challenges faced by the education sector include lack of access to quality education, low educational outcomes, high costs of education, lack of adequate infrastructure and teaching materials, overcrowding in classrooms, and gender inequality. Other challenges include lack of access to technology and digital learning materials, outdated curriculum and pedagogical practices, and inadequate teacher training.", "answer_id": "n4ANAbpR3gvLPP8poPfKZ6", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 9, "text": "The primary factors that influence consumer behavior include cultural, social, and personal factors. Cultural factors include shared beliefs, values, and customs within a society, while social factors include the influence of family, friends, and peers. Personal factors include an individual's lifestyle, attitudes, beliefs, and past experiences. Other factors such as advertising, marketing, and price also have an effect on consumer behavior.", "answer_id": "STJ36GrgQMcaUi7zaoNPit", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 10, "text": "The most effective strategies for conflict resolution in the workplace include:\n1. Establishing open communication: Encouraging open and honest communication between parties is key to resolving any conflict.\n2. Listening actively: Active listening allows both parties to feel heard and allows for a better understanding of the other person's perspective.\n3. Identifying root causes: Understanding the underlying cause of the conflict can help to find more effective solutions that address the core issue.\n4. Focusing on interests, not positions: Focusing on the interests of both parties can help to find creative solutions that satisfy the needs of everyone involved.\n5. Developing mutual respect: Respect is essential to any successful resolution. Building mutual respect can help to create a healthier work environment.\n6. Negotiation: Working together to negotiate a mutually beneficial solution can help to create an agreement that both parties are happy with.\n7. Taking responsibility: Taking responsibility for one\u2019s words and actions can help to create a culture of accountability in the workplace.", "answer_id": "425SwYvqKPAXFGTYKXB7Cs", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 11, "text": "The use of single-use plastic bottles has a number of negative implications on the environment. They are not biodegradable and can take hundreds of years to break down. This means they are around for a long time, releasing potentially harmful chemicals into the environment, and they can be easily mistaken for food by animals. Single-use plastic bottles can also be a source of micro-plastics, which are very small particles of plastic that are difficult to clean up and can be consumed by fish and other aquatic life, leading to health problems in humans who consume those fish. On the other hand, reusable bottles are a much more environmentally friendly choice since they do not produce any waste. They are also often made of materials that are safer for human health, such as stainless steel or glass. Reusable bottles are also more cost-effective since they do not need to be replaced as often as single-use plastic bottles.", "answer_id": "VbNAuj6KAkMdLJQXMo22oK", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 12, "text": "When designing an inclusive and accessible public transportation system, some factors to consider include providing frequent and reliable service, accessible stops and vehicles, a variety of fare options, clear and consistent signage and information, comfortable seating and space for wheelchairs and other assistive devices, and security and safety protocols. Additionally, public transportation systems should be integrated with other services such as public parking, bike and scooter sharing, and shuttles to provide a comprehensive and accessible experience.", "answer_id": "CNGqAeu2QJbQ4QGzHJDPdq", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 13, "text": "Governments can use fiscal and monetary policies to combat economic recessions. Fiscal policies involve the government spending money or cutting taxes in order to stimulate the economy, while monetary policies involve changing the money supply or interest rates to influence economic activity. These policies can be used to stabilize the economy, create jobs, and reduce poverty.", "answer_id": "E8w2qYqnm8iqCrSkUv62sz", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 14, "text": "Language and cultural barriers can lead to misunderstandings and make it more difficult to communicate and form relationships in multicultural societies. Without a common language or shared culture, people can have difficulty understanding each other's perspectives and beliefs, which can lead to mistrust and a lack of understanding between different ethnic groups. To overcome these barriers, it is important to make an effort to learn about different cultures, be open-minded, and take the time to understand each other.", "answer_id": "8o5yMymfzo6kzmp9GK5MWr", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 15, "text": "Artificial intelligence can be used to improve the quality and efficiency of healthcare delivery in a variety of ways. AI can be used to assist with diagnosing diseases by comparing symptoms and medical history to images of known diseases and medical conditions. AI can also be used to analyze laboratory results and patient records to identify potential problems and develop treatment plans. AI can be used to automate administrative tasks and reduce paperwork, as well as identify potential drug interactions and side effects. AI can also be used to automate appointment reminders, facilitate communication between doctors and patients, and even provide virtual health coaching to help patients manage their conditions.", "answer_id": "kbJVEEsdsSScEq5Y5furr7", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 16, "text": "CRISPR-Cas9 is a recently developed gene editing technology that has revolutionized the way scientists are able to edit genomes. The technology uses a guide RNA to direct the Cas9 enzyme to a specific location in the genome, where it will cut the DNA strands. This allows for the insertion or deletion of DNA sequences, which can be used to modify the genetic code of an organism. Potential applications include treating genetic diseases, increasing crop yields, and creating pest-resistant crops. Ethically, the biggest concern is the potential misuse of the technology, which could lead to unintended consequences or be used to alter humanity in ways that could harm us.", "answer_id": "CMUL5ULZuR7YC5EPzCBN2N", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 17, "text": "Vaccinations work by stimulating the body's immune system to protect against infectious diseases. Herd immunity is a concept whereby a population is protected against a certain disease when a certain percentage of the population has immunity to the disease, either through vaccination or having already contracted the disease. This is because when enough people are vaccinated, it reduces the spread of the disease and prevents it from spreading to those who are not immune.", "answer_id": "kEmDDQyNqSkyFihYEEBpuR", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 18, "text": "Social media platforms can have a significant influence on how people consume and share news. By providing instant access to news sources and allowing users to easily share content, these platforms can shape the public's information landscape and the spread of misinformation. Some potential implications of this are that users may not always be able to discern between reliable and unreliable sources, and misleading information may be more easily spread and go viral. This can lead to confusion and mistrust among users, and can potentially be exploited for nefarious purposes. To mitigate these risks, it is important for users to be aware of the potential biases and agendas of the sources they are consuming, and for social media platforms to take proactive steps to curb the spread of misinformation.", "answer_id": "Qs3grQsqFVGK9EVkCkf9PB", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 19, "text": "Cultural, social, and economic factors all play a role in determining an individual's food choices. For example, people growing up in households with limited economic means may be more likely to consume cheaper, less nutritious foods due to economic constraints. Additionally, cultural norms and values can influence food choices, such as the emphasis on large portions typical in some American cultural settings. Finally, social factors such as peer pressure and advertising can also sway an individual's food choices. By understanding and taking into account these various factors, health promotion strategies can be tailored and targeted to help individuals make healthier dietary decisions.", "answer_id": "kzZ6dKN7hkRWjqtdHr7Qns", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 20, "text": "Natural selection is the process by which certain traits or characteristics become more or less common in a population over time. This is due to the fact that those individuals with certain advantageous traits are more likely to survive, reproduce, and pass those traits on to their offspring. Over time, this leads to the evolution and adaptation of species to their environment.", "answer_id": "DPPDG6YGFJij2GCmRL66PU", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 21, "text": "I am a brave knight, loyal to my king and country, and I have come to serve and protect them both. I am skilled in the arts of war and peace, and I am proud to be here at the royal court.", "answer_id": "D62FjDb4nZANzPpfSfsiyn", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 22, "text": "Avast ye maties! To the bottom of the sea we go! In search of hidden riches and fortune, we shall brave the depths and face any danger that comes our way. For glory and honor, let us set sail and find the lost Ark!", "answer_id": "k7E4NNw5kyj9DmvP5Pu2zb", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 23, "text": "If I were a Shakespearean character, I would declare my love for someone in a soliloquy with passionate words and metaphors. I would express my love for them with descriptions of their beauty and the joy I feel when I am with them. I would confess my hopes and dreams for our future together and declare my devotion to them. I would make my love for them known in a way that could never be forgotten.", "answer_id": "KFocjVCejYrU3YmLjAqoUF", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 24, "text": "As a superhero, my origin story is very special. I was born with special abilities that I had to keep hidden for a long time. I was scared to use them at first, but eventually I discovered that I could use my powers to protect people and make the world a better place. I've had to face a lot of challenges along the way, but I've also had a lot of help from friends and family. Now, I use my abilities to help people in need and fight for justice.", "answer_id": "dq8Sm9djS7e7y9sG9vmMJf", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 25, "text": "If I were a time traveler from the year 3000, I would tell people about the incredible advancements in technology, such as the ability to travel through time, space, and dimensions; the development of intelligent robots and autonomous vehicles; the emergence of virtual reality and augmented reality; and the rise of artificial intelligence and machine learning.", "answer_id": "XZ8fG8e6u7CyKd2moK6abe", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 26, "text": "The game was a nail-biter, with both teams trading leads throughout the game. With only seconds left on the clock, the home team made a bold move and passed the ball to their star player, who took the ball down the court and made a layup at the buzzer to seal the victory for the home team!", "answer_id": "oKaXHfoK4pXwrefFWXmeA8", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 27, "text": "My signature dish is a seamless blend of traditional and modern cooking techniques. I use only the freshest ingredients to create a unique and unforgettable dining experience. The dish is a perfect balance of flavors and textures, with a subtle hint of my personal style. It is a dish that I am proud to call my own.", "answer_id": "ZwiZfvDWm7SETKNBfDk7Mb", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 28, "text": "At the summit of Mount Everest, you are filled with a sense of accomplishment and joy. The view from the top is absolutely breathtaking - you can see for miles and miles, with the majestic Himalayan mountain range stretching out in all directions. It is a truly unforgettable experience.", "answer_id": "DxYopRe2LcTJMy3FWu6btd", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 29, "text": "As a colonist on Mars, my daily life is filled with challenges. Finding resources and creating a sustainable environment is a priority. I face a number of challenges including extreme temperature fluctuations, limited access to resources, and the difficulty of travelling to and from the planet. Additionally, I must be mindful of my physical and mental health since I am so far from home. Despite these challenges, I am grateful to be able to explore and experience this new world.", "answer_id": "WC3UJVh4jQ5RUkpcRMU98L", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 30, "text": "In the post-apocalyptic world, I am a survivor by necessity. I scavenge for food and supplies, and I'm always on the lookout for potential allies. I've encountered a few people who have managed to survive, and together we have formed an alliance to help each other. We hunt for food, build shelter, and work together to stay alive. We also share knowledge and skills, like how to start a fire or how to use a weapon. We look out for each other, and our alliance has strengthened our chances of survival.", "answer_id": "gTvgn6ksDjGGgdprw6AG5A", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 31, "text": "There are a few ways to tell if a restaurant is popular among locals or mainly attracts tourists. Firstly, look at the clientele - if the majority of people there are tourists, it's likely that the restaurant is mainly attracting tourists. Secondly, check online reviews - if the reviews are mainly from tourists, then it's likely that the restaurant is popular with tourists. Finally, look at the prices - if the prices are higher than average for the area, it could be a sign that the restaurant is popular with locals. This information can be useful to get an idea of what kind of experience to expect, as locals might know about different aspects of the restaurant that tourists wouldn't.", "answer_id": "3q7giCk2BA3Ye4Tm9HC2iw", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 32, "text": "Some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed include: not asking any questions or engaging in the conversation, avoiding eye contact, fidgeting or stammering when speaking, repeating questions or comments made by other people, and nodding along without any signs of understanding.", "answer_id": "hRGsxy86v26SC4yAQS29X4", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 33, "text": "Some people prefer the tactile and visual experience of using a paper map, and others may prefer to ask for directions from locals in order to get a more personalized experience. Additionally, GPS devices and smartphone apps can sometimes be inaccurate or have limited battery life, while a paper map or asking for directions may be more reliable.", "answer_id": "3n49A5ggJERfXYrLns3ZeU", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 34, "text": "One way to tell if someone is genuinely interested in a conversation is to observe their body language and facial expressions. Are they making an effort to maintain eye contact? Are they leaning in and actively listening to what you are saying? Do they ask questions and provide relevant answers? If so, it is likely that they are genuinely interested in the conversation. Additionally, if someone is simply being polite, they may not ask questions or engage in the conversation as much, and may not make an effort to maintain eye contact.", "answer_id": "ErCpFtPuYVru4oTTk4WrxG", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 35, "text": "Shopping at a small, locally-owned business can benefit the local community by keeping money in the area and supporting local jobs. Additionally, these businesses tend to offer a more personal experience and higher quality products than large chain stores. Furthermore, shopping at small businesses can help create a sense of place and community, and can help maintain a unique local culture.", "answer_id": "PTNoCRMZWoJk8HaKX7fW45", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 36, "text": "There are several ways to assess the credibility of a source of information. Firstly, you can look at the author's credentials and experience in the relevant field. Secondly, you can check the source of the information, such as whether it is from a reliable website or publication. Thirdly, you can look at the evidence presented in the article and whether it is backed up by reliable sources. Finally, you can read other people's reviews or comments about the article to get a better idea of its credibility.", "answer_id": "n8cFs9KENNwZ4z3SR4iXTr", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 37, "text": "Some people enjoy the sensation of being scared because it can create a feeling of excitement, enhance their emotional state, and provide a sense of thrill and adventure. Others may avoid these experiences because they are afraid of the unknown, or because they don't enjoy the feeling of being scared. Everyone is different, and some people may be more attracted to thrilling and exciting activities while others may prefer calmer activities.", "answer_id": "GzxL9mmEK5RzKqRbqBMUVC", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 38, "text": "By observing the behavior of others in a social situation, one can gain clues as to the cultural norms and expectations of a group. For example, watching how people interact with one another, how they address each other, how they handle disagreements, and how they go about solving problems can provide insight into the cultural values of the group. Additionally, observing body language, facial expressions, and other nonverbal cues can offer clues as to the accepted norms of behavior in a particular culture.", "answer_id": "QpoHFgb9SzwuaXQQUuBUQD", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 39, "text": "It is an interesting question, and one that has been debated for quite some time. I think there are valid arguments on both sides. On the one hand, exploring space is a remarkable human endeavor and could lead to tremendous scientific discoveries and technological advances. On the other hand, there are many pressing issues that need to be addressed on Earth, such as poverty, inequality, and climate change. Each side would argue that their cause is more important, and it is ultimately up to each individual to decide which one they feel more strongly about.", "answer_id": "Fxe6MS4GpP3LMDUwzY2cPA", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 40, "text": "It is important to strike a balance between job creation and technological progress. Automation can increase efficiency and productivity, but it should not come at the expense of job security and people's livelihoods. Therefore, it is essential to create policies and initiatives that promote both job creation and technological progress. This could include investing in training and education to ensure that people have the skills necessary to compete in the modern job market, as well as incentivizing companies to invest in technologies that create jobs and stimulate economic growth.", "answer_id": "mJiQ2FGR4Xb8kmhZjharkw", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 41, "text": "On average, the human eye blinks about 20 times per minute, or about 14,400 times per day. In a lifetime, this means that the average human will blink roughly 50 million times. This may seem like a lot, but it serves an important purpose. Blinking helps to keep the eyes lubricated and prevents them from drying out. It also helps to spread tears over the surface of the eye, washing away foreign particles and keeping the eye clean. Additionally, blinking helps to reduce the risk of eye infections by helping to clear away bacteria and other foreign substances.", "answer_id": "6Kph4RHRKEZ4YUoaHuEhBv", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 42, "text": "A grain of salt contains 102.98 atoms. To calculate this, we first need to know the atomic weight of a single atom. The atomic weight of an atom is the number of protons and neutrons in the nucleus of an atom, which determines its atomic mass. The atomic weight of a single atom of salt is 58.943 g/atom. Therefore, a grain of salt contains 102.98 atoms, which is equivalent to 60.98 grams.", "answer_id": "WBwpBQwhxn5kxLDb7MschC", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 43, "text": "Approximately 2000 lightning strikes occur on Earth each day. This is because the atmospheric conditions must come together in a particular way for a lightning strike to occur. Firstly, a large amount of electric charge must accumulate in the atmosphere, typically in a storm system. Then, the air must become increasingly unstable, leading to rising air and a strong updraft. This causes an electric breakdown of the air, and then an exchange of electricity occurs from the cloud to the ground, forming a lightning bolt. As these conditions are necessary for a lightning strike to occur, about 2000 lightning strikes happen on Earth each day.", "answer_id": "kf8nahQVci2ZLaYikagB7U", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 44, "text": "It would take about 10 million balloons to lift a house like in the movie Up. The balloons would need to be filled with helium in order for the house to be lifted. Each balloon would need to be filled with about 89.1 cubic feet of helium in order to lift 500 pounds. To calculate how many balloons would be needed, simply multiply the weight of the house (264.72 lbs) by the number of cubic feet of helium needed to lift 500 pounds (89.1). Therefore, it would take 10 million balloons to lift a house like in the movie Up.", "answer_id": "Gptgryd4o2dC8V5aqRmeJJ", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 45, "text": "According to a 2017 study, over 6.3 billion text messages are sent globally in a minute. This number is expected to grow to 6.8 billion by 2021. This is due to the increasing number of smartphones, better network coverage, and the increasing popularity of texting for communication. Furthermore, the increase in smartphones and better network coverage has also led to the rise of instant messaging applications such as WhatsApp, which has further increased the amount of text messages sent per minute.", "answer_id": "RfBWW8ZhdfTuTMb454Un4o", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 46, "text": "It is estimated that over 1 billion words are spoken on Earth every day. This is due to the fact that humans are highly social creatures, and language is the primary way that we communicate with each other. In addition, language is a key component of all cultures and societies, and is constantly being used to exchange information and ideas. Furthermore, language is integral to the way we think and process information, as well as to how we connect with each other on an emotional level. Therefore, it is not surprising that so many words are spoken on a daily basis.", "answer_id": "neGgLYm47JvqN8qkw8VeoW", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 47, "text": "On average, about 100 snowflakes fall per second during a typical winter. This number can vary drastically depending on the location and severity of the storm. Snowflakes are made up of tiny ice crystals, which form in the atmosphere when the temperature is below freezing and there is moisture in the air. As air moves and mixes, the crystals come together to form snowflakes. The shape, size and number of snowflakes depend on the atmospheric conditions, such as temperature and humidity, when they form.", "answer_id": "KSePbzLwsYasR3aui4HU8h", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 48, "text": "It is estimated that all the books ever written would total about 100 billion pages. This is based on the fact that there are roughly 130 million unique titles in existence, and each book on average contains about 200 pages. This means that if all the books were printed out, the stack of paper would be about 400 miles high. Additionally, if each book was represented by a single page, the total number of pages would be approximately 100 billion. Throughout history, books have served as a way to preserve knowledge and information, as well as to share ideas and stories with others. As technology advances, books are increasingly available in digital formats, which allows for easier access and storage. Regardless of the medium, books will continue to be a valuable source of information for generations to come.", "answer_id": "SWWfLpP2gnWHLbYRZXNsW3", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 49, "text": "Since the beginning of life, the Earth has orbited the Sun a total of 4.3 billion times. This is calculated by multiplying the average length of a year (365.25 days) by the number of years since life began (approximately 4.5 billion years). Life on Earth relies on the Sun's energy reaching the Earth's atmosphere and being redistributed by the atmosphere, the land surface, and the oceans. Without this energy, life on Earth would not be possible. Therefore, the Earth's orbit around the Sun is integral to the sustainability of life on Earth.", "answer_id": "WaBwKYhs7eAG22qCGLH2j3", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 50, "text": "Since the beginning of recorded music, countless songs have been recorded. The exact number is difficult to quantify, but conservative estimates put the number of songs recorded throughout history at several hundred thousand. This includes songs from all genres, from classical music to modern pop. The invention of recording technology has allowed musicians to preserve their music and share it with the world. This has enabled generations to access and enjoy music from all eras. With the rise of digital music, the number of recordings available has grown exponentially, and new songs are being recorded all the time.", "answer_id": "MfMJeE9om7qyBbqopHouf4", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 51, "text": "If the Internet had been invented during the Renaissance period, it likely would have been used primarily by the wealthy and powerful to access knowledge and information, as there was no widespread system of public education at the time. It also would have been a much slower and more expensive process to access information, as there were no high-speed connections or digital networks available.", "answer_id": "TjWPRDM6JFpPF8xeRptCKb", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 52, "text": "If the Aztecs had successfully repelled the Spanish conquistadors, their culture and way of life would have likely remained intact. The Spanish would not have been able to colonize Mexico, and their influence on the region would have been much less. It's likely that the Aztecs would have continued to expand their empire and control the areas that are now parts of Central and South America. Additionally, the Native American populations in the area would have likely been able to maintain more of their traditional customs and beliefs, and their languages may not have been as heavily impacted by Spanish.", "answer_id": "iR2tYTsWTFENEP7Qy9RgtX", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 53, "text": "If the Black Death had not occurred in the 14th century, the population of Europe may have continued to grow unabated, leading to more rapid urbanization and economic growth. It's likely that the Renaissance would have begun sooner, and the scientific revolution may have occurred earlier as well. Without the Black Death, there may have been no need for the industrial revolution, or at least it may have occurred later and on a smaller scale. Additionally, the spread of diseases may have been slower, leading to greater population density in certain areas.", "answer_id": "AZdS8xAi3GwAmCqkNSnnwv", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 54, "text": "If Isaac Newton had focused on biology instead of physics, he might have made important discoveries in the field of medicine and biology. He may have studied the human body and its functions in greater detail, and possibly even made important breakthroughs in treating diseases. He may also have studied the behavior of different animals in the wild and made important observations about their habits and habitats.", "answer_id": "VmwifF2JD5osYKDTqv2ZRS", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 55, "text": "If the Beatles had never formed as a band, the music world would have been drastically different. The British Invasion of the 1960s might not have happened, or it could have happened in a much different way. The cultural impact of the Beatles, including their influence on other musicians, fashion and society, may never have existed.", "answer_id": "mUL5UPj3qDGaCriEjL2U3B", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 56, "text": "If Alan Turing had not cracked the Enigma code during World War II, the outcome of the war could have been very different. It's possible that the Allies would have eventually defeated the Axis powers, but it likely would have taken much longer and cost many more lives. With the Enigma code cracked, the Allies were able to gain a critical advantage over the Axis powers and ultimately win the war.", "answer_id": "dVdwUoVrAQJDuWxiodykiw", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 57, "text": "If the Suez Canal had never been constructed, it would have major implications for international trade and navigation. The Suez Canal is an important shipping route connecting the Mediterranean Sea to the Red Sea, which in turn connects to the Indian Ocean. This allows for a shorter and more direct route for ships traveling between Europe and Asia, reducing travel time significantly. Without the Suez Canal, ships would have to travel around the Cape of Good Hope at the southern tip of Africa, which is much more time consuming and costly. This would make international trade more difficult and expensive, and could hinder global economic growth.", "answer_id": "EiNn9jjfy7dga6xfCtLtF8", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 58, "text": "If the Maya civilization had not mysteriously collapsed, it is likely that their influence and power would have continued to grow. They would have likely continued to build large cities and monuments, develop their writing and astronomical systems, and expand their trade networks. Additionally, their advances in agriculture and water management may have led to increased food production and a larger population.", "answer_id": "eqG9f2R9hXVyZrZMpcqAYq", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 59, "text": "If Christopher Columbus had not discovered the Americas, the indigenous people of the Americas would have continued to live in relative isolation from the rest of the world. It is likely that European colonization and the consequences of that would have been delayed, leading to a different modern world. It is also possible that other Europeans would have made the voyage across the Atlantic and discovered the Americas instead.", "answer_id": "ex42CLEzDVC2TxAvMCtYQJ", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 60, "text": "If Vincent van Gogh had been successful during his lifetime, he may have been able to achieve the recognition and financial stability that he desired. He may have been able to travel and gain more exposure for his artwork, as well as connect with other like-minded individuals. He may have been able to create more artwork and expand his portfolio. He may also have been able to find a mentor or patron who could have helped him further his career.", "answer_id": "RfFANSwHeCm3Pqe2AJ86Dk", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 61, "text": "The program should read a text file line by line. It should then check each line for the specified word, counting up each time the word is found. Finally, it should output the number of occurrences of the word.", "answer_id": "nzxdtWmgQhRkHeVAanVogZ", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 62, "text": "The following Python function can be used to find the longest common subsequence of two input strings using dynamic programming:\ndef longest_common_subsequence(s1, s2):\n\"\"\"This function finds the longest common subsequence of two strings.\n\n Arguments:\n s1 {string} -- The first string\n s2 {string} -- The second string\n\n Returns:\n {int} -- The length of the longest common subsequence\n\n\"\"\"\n n1 = len(s1)\n n2 = len(s2)\n \n # Creating an empty matrix to store the results\n results = [[None]*(n2 + 1) for _ in range(n1 + 1)]\n \n # Initializing the first row to 0\n for i in range(n1 + 1):\n results[i][0] = 0\n \n # Filling the first row with 0\n for i in range(n1):\n results[i][0] = 0\n \n # Calculating the longest common subsequence\n for i in range(n1 + 1):\n for j in range(n2 + 1):\n if (i == 0 or j == 0):\n results[i][j] = 0\n elif (s1[i-1] == s2[j-1] and (results[i-1][j-1] < (i-j) or results[i-1][j-1] == 0)):\n results[i][j] = max(results[i-1][j-1] + 1, i+j)\n else:\n results[i][j] = 0\n \n return results[n1][n2]", "answer_id": "kRo6z5qmeiSM8Xz6BPJHLC", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 63, "text": "The following regular expression can be used to validate an email address in Python:\n^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\\.[a-zA-Z0-9-.]+$", "answer_id": "7kj8odEE79UBdR4pT6u6TY", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 64, "text": "Here is a program to find the nth Fibonacci number using dynamic programming:\n1. Set a variable 'n' to the nth Fibonacci number.\n2. Set two other variables 'a' and 'b' to 0 and 1 respectively.\n3. Set a loop from 0 to (n-1).\n4. In the loop, set another loop from 0 to (a+b).\n5. In the inner loop, if 'a' is even and 'b' is odd, set 'a' to a+1 and 'b' to b+1.\n6. If 'a' is odd and 'b' is even, set 'a' to a-1 and 'b' to b-1.\n7. After the inner loop, set 'n' to 'a' and break out of the outer loop.\n8. Print the value of 'n'.\nThis program uses the idea of dynamic programming to compute the nth Fibonacci number. The program first sets two auxiliary variables 'a' and 'b' to 0 and 1 respectively. Then it sets a loop from 0 to (n-1). In the loop, it sets another loop from 0 to (a+b). In the inner loop, if 'a' is even and 'b' is odd, it sets 'a' to a+1 and 'b' to b+1. If 'a' is odd and 'b' is even, it sets 'a' to a-1 and 'b' to b-1. After the inner loop, it sets 'n' to 'a' and breaks out of the outer loop. Finally, it prints the value of 'n'.", "answer_id": "UULZ4GLz9j9ot7tN7x3ySD", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 65, "text": "A binary search algorithm can be used to find a specific element in a sorted array by repeatedly splitting the array in half and examining the middle element to determine whether the target element is greater than, less than, or equal to it. If the target element is found, the algorithm immediately returns the index of the element. If it isn\u2019t, the array is split in half, with one half discarded based on the comparison. The process is repeated until the element is found or the array is completely searched.", "answer_id": "Rh8P3bxSJx4p6nCZkSZBrh", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 66, "text": "Implementing a queue using two stacks in Python is relatively simple. First, you need to create two stacks, which will act as the queue. Next, you need to define an enqueue and dequeue method, which will add and remove items from the queue. Finally, you need to create an instance of the queue class and start using it.", "answer_id": "R3aYUUtY46ht4ToU8US96x", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 67, "text": "The program can be implemented using two nested for loops. The inner loop can be used to iterate through the elements of the first array, while the outer loop can be used to iterate through the elements of the second array. If the two elements being compared are the same, the program should mark those elements with a \"found\" flag and move on. Once the inner loop completes, the program should check if the \"found\" flag is set for all elements in the second array. If it is, then the common elements have been found.", "answer_id": "MQJFbH3KoFaSMmPgmK9LPa", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 68, "text": "The value of f(2) is 1.", "answer_id": "dBWqryuzG6BHRdPgrWWpbH", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 69, "text": "x = 6.", "answer_id": "W74byrCbE7XiJWjcPTZS7x", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 70, "text": "The length of the line segment is 8 units.", "answer_id": "7WC5quGKdv8hADCxWaWZ57", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 71, "text": "Yes, I can help you write a formal email to a potential business partner proposing a joint venture. Here's a template that you can use:\nSubject Line: Proposing a Joint Venture\nDear [Name],\nI am writing to propose a joint venture between our two companies. I believe that by combining our resources, we can create a powerful and profitable partnership. Our combined experience and knowledge in the industry will help us to create something unique and special.\nI am confident that our partnership will bring mutual benefits, such as increased profits and access to new markets. Please let me know if you are interested, and we can discuss the details.\nThank you for your time and consideration. I look forward to hearing from you soon.\nSincerely,\n[Your Name]", "answer_id": "AZER7D3RKZ9F9SXHiMjdqM", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 72, "text": "Yes, I can help you write a resignation letter to your current employer. The letter should be written in a polite and professional tone and should express your appreciation for the opportunities and experiences you have gained during your employment. It should also include a clear statement of your intention to resign and the date of your last day. Here is a template you can use:\nDear [Employer],\nI am writing to inform you of my intention to resign from my position as [Position], effective [date]. I have appreciated my time here and I am grateful for the opportunities and experiences I have gained during my employment. \nI thank you for your support and understanding.\nSincerely, [Your Name]", "answer_id": "MSrdDafr77UvSHCnsPMSP3", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 73, "text": "The letter of recommendation should be structured in a formal manner. Begin by introducing the student and explaining your relationship to them. Describe their qualifications, including their academic performance and relevant skills, and mention any particular accomplishments or awards they have received. Explain why the student is a strong candidate for the program and how they will make a positive contribution. End the letter by reaffirming your recommendation and offering your contact information for further questions.", "answer_id": "hxkjUkDkXhGP78Vo74B4WE", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 74, "text": "Dear valued customers, \nWe are excited to announce the launch of our new software solution \u2013 a revolutionary product designed to transform the way you do business! Our new software is an intuitive and powerful tool that can help you streamline processes, save time and money, and increase overall productivity. \nAt [Company Name], we are committed to helping you stay ahead of the competition, and we believe that our new software solution is the perfect tool to help you achieve your goals. Our experienced team of professionals has worked hard to ensure that this product meets the highest standards of quality, and we are confident that it will make a real difference for your business. \nWe invite you to join us in this journey of transformation and discover how our new software can help you achieve your vision. Sign up now and get a free demo to learn more about this revolutionary product. \nThank you for your trust and support.", "answer_id": "hv9jpG9PxeJ9AnFdNzH3Jv", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 75, "text": "mer,\nWe apologize for the delay you experienced with your order and thank you for your patience. We have identified the cause of the issue and have taken steps to ensure it does not happen again. We value your business and hope to be able to serve you better in the future.", "answer_id": "X5BVHT5pFpWxB88SzERrEG", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 76, "text": "The video will explore the history and cultural significance of jazz music. We will begin by discussing its roots in West African cultural music traditions and its journey through the 19th and 20th centuries, focusing on the key musicians, composers, and genres that have defined the genre. We will also explore the impact of jazz on popular culture, such as in film and literature, as well as its influence on other musical styles. Finally, we will examine the legacy of jazz and its continued importance today.", "answer_id": "Af2SufKwvDBriNXnZX9UoP", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 77, "text": "I have composed a travel blog post about my recent trip to Hawaii. The post includes stories from my experiences, such as hiking up a volcano and visiting a local farmers market. I also included plenty of photos to give the reader a sense of place. I highlighted some of the unique cultural experiences, such as an evening spent listening to live Hawaiian music, as well as must-see attractions, like the beautiful beaches and lush rainforests. I ended the post with a personal reflection on the trip and why it was so special to me.", "answer_id": "HVnqVdSGRwQfqSsiYxtTTn", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 78, "text": "I recently watched the science fiction movie \"The Future is Blue\" and it was an incredible cinematic experience. The movie follows the story of a scientist who travels back in time to save the world from a deadly virus. The plot was interesting and engaging, with plenty of twists and turns to keep you guessing. The characters were also well developed and likable, making it easy to get invested in their story. The special effects were stunning and added a lot of visual oomph to the movie. All in all, I highly recommend this movie - it's an exciting and thought-provoking ride from start to finish.", "answer_id": "8hcmmjzyMyhFJ8ZwqJJfvw", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
+ {"question_id": 79, "text": "The podcast episode would begin by introducing the topic of how streaming platforms have changed the music industry. This would include discussing the impact of streaming services like Spotify and Apple Music on music consumption, the effect on musicians' ability to make money, and how streaming has changed the way people find and discover music. The episode would then explore the effects of streaming on the music industry from multiple angles, such as the impact on musicians and songwriters, how streaming has changed the sound of popular music, and the implications for the music industry's future. Finally, the episode would conclude by discussing how streaming has both benefited and hurt the music industry, and what the future holds for music consumption.", "answer_id": "CU6gGfJqQmwZNxEG2n2Uda", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
80
+ {"question_id": 80, "text": "The symphony concert was a wonderful experience. The orchestra played beautifully, with a rich and soulful sound that filled the hall. The audience was captivated, and you could feel the energy of the music in the air. The performance was a perfect balance of classic and modern pieces, and the conductor's passion was contagious. The atmosphere was festive, and the encore performance was a perfect way to end the evening. All in all, it was a brilliant concert and a delight for the ears.", "answer_id": "5SWbCoLV6NRrq5RtWQK4jp", "model_id": "alpaca-13b:v1", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
answer_bard.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
answer_gpt35.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
answer_llama-13b.jsonl ADDED
@@ -0,0 +1,80 @@
1
+ {"answer_id": "J3UA6eGXGyFeUGqGpP3g34", "model_id": "llama-13b:v1", "question_id": 1, "text": "The following are some strategies that may help you improve your time management skills:\n1. Make a list of priorities.\n2. Create a detailed schedule that includes time for regular tasks, meetings, breaks, and appointments.\n3. Set specific goals.\n4. Use a calendar or other scheduling software.\n5. Develop an agenda for each meeting to ensure that it stays on track.\n6. Stick to your schedule and don't get distracted.\n7. Make use of your free time.\n8. Make time for exercise, socializing, and relaxing.\n9. Keep a log of your time usage.\n10. Don't get overwhelmed by what you have to do.\n11. Take care of your health.\n12. Manage stress.\n13. Get enough sleep.\n14. Strike a balance.\n15. Always be prepared.\n16. Give yourself rewards.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
2
+ {"answer_id": "jZVDNTtFnNNd6vgkGjzAGa", "model_id": "llama-13b:v1", "question_id": 2, "text": "People experience stress from time to time, and it can be caused by a wide range of factors, including work, family, friends, money, and health. Stress can be a motivating factor, but it can also be harmful. Here are some effective ways to deal with stress:\n1. Taking time to relax. It's important to take time for yourself to relax and reduce your stress. Doing so doesn't have to be expensive or time-consuming. It could simply be reading a book, taking a walk, or listening to music.\n2. Prioritizing your responsibilities. It's important to prioritize your responsibilities and focus on the things that are most important. You should also be able to say no to tasks that you don't have time for.\n3. Getting enough sleep. Getting enough sleep can be very helpful for dealing with stress. If you're not getting enough sleep, try to go to bed earlier and develop a routine.\n4. Eating a healthy diet. It's important to eat a healthy diet, which can help reduce stress and make you feel better about yourself.\n5. Exercising. Exercising is a great way to relieve stress. It can also help reduce the harmful effects of stress and improve your overall health.\n6. Learning how to manage your time. Learning how to manage your time can help you avoid unnecessary stress. You should also try to set realistic goals and not take on too much at once.\n7. Improving your social network. Having a good social network can be very helpful for dealing with stress. You should try to spend time with people that you trust and confide in.\n8. Developing a support network. A support network is a group of people who can offer you help, guidance, and support. They can be friends, family members, or even a professional therapist.\n9. Getting professional help. If you're struggling to deal with stress, it can be beneficial to get professional help. A therapist or counselor can help you to understand your stress better and develop strategies for dealing with it.\n10. Talking to someone. You can also talk to someone about how you feel and ask for their advice. It can be helpful to get a fresh perspective from someone and let things off your chest.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
3
+ {"answer_id": "NbSxRXf8Yuvjr7HC6AvrPw", "model_id": "llama-13b:v1", "question_id": 3, "text": "Main differences between Python and JavaScript are:\n1. Python is an object-oriented programming language. JavaScript is a scripting language.\n2. Python is a compiled language. JavaScript is an interpreted language.\n3. Python is a general purpose programming language. JavaScript is a special purpose programming language.\n4. Python runs on the JVM in Java Applets. JavaScript runs on the browsers.\n5. Python is a powerful programming language. JavaScript is a simple programming language.\n6. Python is a good language for web development. JavaScript is a good scripting language.\n7. Python is a high level programming language. JavaScript is a low level programming language.\n8. Python is a dynamic programming language. JavaScript is a static programming language.\n9. Python is a strict programming language. JavaScript is a relaxed programming language.\n10. Python has a large standard library. JavaScript has a small standard library.\n11. Python is very stable. JavaScript has many releases.\n12. Python is a simple language. JavaScript is a complex language.\n13. Python is easy to learn. JavaScript is hard to learn.\n14. Python is a secure programming language. JavaScript is not a secure programming language.\n15. Python is a multi paradigm language. JavaScript is a single paradigm language.\n16. Python is good for data science. JavaScript is not good for data science.\n17. Python is a general purpose programming language. JavaScript is a special purpose programming language.\n18. Python is a programming language. JavaScript is a scripting language.\n19. Python is a compiled programming language. JavaScript is an interpreted programming language.\n20. Python is a strongly typed programming language. JavaScript is a weakly typed programming language.\n21. Python supports object-oriented programming. JavaScript supports prototype-based programming.\n22. Python supports type inheritance. JavaScript supports type prototyping.\n23. Python supports operator overloading. JavaScript does not support operator overloading.\n24. Python comes with a large standard library. JavaScript comes with a small standard library.\n25. Python is a very stable programming language. JavaScript is a simple programming language.\n26. Python is a powerful programming language. JavaScript is a simple programming language.\n27. Python is a high level programming language. JavaScript is a low level programming language.\n28. Python is a dynamic programming language. JavaScript is a static programming language.\n29. Python has a very active community. JavaScript has a small community.\n30. Python is a general purpose programming language. JavaScript is a special purpose programming language.\n31. Python is a multi-paradigm language. JavaScript is a single paradigm language.\n32. Python is a secure programming language. JavaScript is not a secure programming language.\n33. Python is a general purpose programming language. JavaScript is a simple programming language.\n34. Python is a scripting language. JavaScript is a compiled language.\n35. Python is a compiled language. JavaScript is a scripting language.\n36. Python is a general purpose programming language. JavaScript is a web scripting language.\n37. Python has a large standard library. JavaScript has a small standard library.\n38. Python is a very stable programming language. JavaScript is a very dynamic programming language.\n39. Python is a high level programming language. JavaScript is a low level programming language.\n40. Python is easy to learn. JavaScript is hard to learn.\n41. Python is a general purpose programming language. JavaScript is a simple programming language.\n42. Python is a multi-paradigm programming language. JavaScript is a single paradigm language.\n43. Python is a compiled language. JavaScript is an interpreted language.\n44. Python is a dynamic programming language. JavaScript is a static language.\n45. Python is a general purpose programming language. JavaScript is a web scripting language.\n46. Python is a secure programming language. JavaScript is not a secure programming language.\n47. Python is a general purpose programming language. JavaScript is a static programming language.\n48. Python is a general purpose programming language. JavaScript is a weakly typed programming language.\n49. Python is a general purpose programming language. JavaScript is a dynamic programming language.\n50. Python is a general purpose programming language. JavaScript is a multi paradigm language.\n51. Python is a general purpose programming language. JavaScript is a simple programming language.\n52. Python is a general purpose programming language. JavaScript is a scripting language.\n53. Python is a general purpose programming language. JavaScript is a special purpose programming language.\n54. Python is a multi-paradigm programming language. JavaScript is a", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
4
+ {"answer_id": "E9zrPDdncYvAebEaxiPXfx", "model_id": "llama-13b:v1", "question_id": 4, "text": "Here are some ideas to help you become more productive when working from home:\n1. Set up a proper workspace: It's important to have a separate space for working, so that you don't feel tempted to lounge around or do other things when you should be working.\n2. Block out distractions: It's important to focus on your work and not get distracted. You can block out distractions by turning off your phone, closing your door, or working in a separate room away from household noises.\n3. Take breaks: It's important to take breaks and get a change of scenery during the day. This can help you be more productive and improve your concentration.\n4. Avoid multitasking: You should avoid trying to do multiple tasks at the same time, as this can make you less productive and reduce your overall quality of work.\n5. Have an organized schedule: Having a clear schedule can help you stay on track and be more productive. This can include things such as setting specific working hours, creating a to-do list, and scheduling meetings with colleagues.\n6. Set goals: You should set specific goals for yourself to help you stay motivated and on track. This can include things such as writing an article for a certain time period and meeting a certain word count.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
5
+ {"answer_id": "36TYpADDHAfRvtvKitfSXR", "model_id": "llama-13b:v1", "question_id": 5, "text": "Quantum computing is a type of computing that uses quantum-mechanical phenomena, such as superposition and entanglement, in its operation. In simple terms, a quantum bit or qubit is used to represent information, and a number of qubits can be used to represent a classical bit. A quantum computer can do things that a classical computer cannot do, such as factoring a large number more efficiently.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
6
+ {"answer_id": "8vtLLYRM352mZ8PXXbwzBT", "model_id": "llama-13b:v1", "question_id": 6, "text": "The number one difference between plant-based and animal-based protein sources is that animal-based sources are complete protein sources, while plant-based sources are not. Animal-based sources have all of the essential amino acids, while plant-based sources do not. Animal-based protein sources include meat, poultry, eggs, milk, and cheese. Plant-based protein sources include legumes, seeds, and nuts.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
7
+ {"answer_id": "mrCniyWYvAtxWhgMdXq2BG", "model_id": "llama-13b:v1", "question_id": 7, "text": "Critical thinking is a vital skill to have, but it can be difficult to develop. Here are some ways you can start developing your critical thinking skills:\n1. Analyze and interpret: When you're presented with information, try to analyze and interpret it on your own terms. When you start to analyze and interpret information on your own, it will become easier to do so over time.\n2. Assess your information: When you receive new information, try to assess whether it's valid or reliable. This is important for evaluating and forming your own opinions.\n3. Use your intuition: Use your intuition to complement your critical thinking skills. When you pay attention to your gut feelings, they'll usually tell you when you need to double-check your information.\n4. Learn from your mistakes: When you make a mistake, try not to beat yourself up about it. Instead, learn from your mistakes and use them to your advantage.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
8
+ {"answer_id": "S8fmdbvpvbgTUyDU9V3M3N", "model_id": "llama-13b:v1", "question_id": 8, "text": "One of the major challenges faced by the education sector today is the lack of skilled teachers. Finding enough qualified teachers to work in the education sector, and in the numbers that are required, is one of the biggest problems facing the education sector, largely due to the lack of funding for teacher training and the low salaries that teachers receive.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
9
+ {"answer_id": "KmuNjvNKRyaFwaBZTLArcG", "model_id": "llama-13b:v1", "question_id": 9, "text": "Consumer behavior can be influenced by a variety of factors, including:\n1. Psychological factors, such as perception, beliefs, attitudes, personality, motivation, emotion, and social influences;\n2. Social factors, such as family, friends, and reference groups;\n3. Cultural factors, such as reference groups, family, and cultural roles;\n4. Demographic factors, such as age, gender, occupation, income, and education level;\n5. Economic factors, such as income, wealth, and spending patterns;\n6. Personal factors, such as lifestyle, job, environment, and individual traits;\n7. Situational factors, such as location, time, and situation; and\n8. Market factors, such as competition, prices, and promotional activity.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
10
+ {"answer_id": "HNNT9k5htiMeRwNSD6dcz4", "model_id": "llama-13b:v1", "question_id": 10, "text": "Conflict resolution strategies in the workplace can range from simple to more involved, depending on the situation:\n1. Avoidance: Avoiding the issue can be an effective strategy in the short term, but it can lead to more tension and conflict in the long term.\n2. Compromise: Compromise can be a good strategy for solving an issue, but it can require a give-and-take and lead to the loss of respect between the parties involved.\n3. Negotiation: Negotiation is a good strategy for solving an issue when both parties are willing to work together and are open to the idea of changing their views and positions.\n4. Mediation: Mediation can be a good way to resolve an issue as it can be more impartial and less adversarial than other strategies.\n5. Collaboration: Collaboration can be a good strategy for solving an issue as it can help both parties to develop new perspectives and solutions.\n6. Formal resolution: Formal resolution can be a good strategy for resolving an issue, especially if the situation is particularly complex or difficult.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
11
+ {"answer_id": "ChXjhDDikxU9FV3CADs6Ym", "model_id": "llama-13b:v1", "question_id": 11, "text": "1. Single-use plastic bottles can take between 450 and 1,000 years to decompose, so they can cause serious environmental problems. Single-use plastic bottles can end up in the ocean, either as litter or as part of the Great Pacific Garbage Patch. 2. Single-use plastic bottles can be a big source of plastic debris in the ocean, which can be ingested by marine animals, killing them or causing them serious health problems. 3. In addition to polluting the environment, single-use plastic bottles can also be a serious health risk to humans. Single-use plastic bottles can pose a risk to human health because of the chemicals used in their manufacturing.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
12
+ {"answer_id": "5wsPnN3VmmSkahgugFNo7u", "model_id": "llama-13b:v1", "question_id": 12, "text": "A public transportation system should be designed to accommodate the needs of all riders, including people with disabilities. This requires planning and design that considers the needs and abilities of people with sensory, physical, developmental, and mental disabilities. The system should be accessible to and usable by people with disabilities, including those who use mobility devices, people who are visually impaired, and people who are deaf or hard-of-hearing. For example, the system should have accessible stops, routes, and vehicles, as well as ramps and elevators to help people board. It should also offer a wide range of accessible routes, including accessible routes to the transit center. It should also have reliable, consistent, and easy-to-use signage. The system should also provide reliable, easily accessible, and well-maintained facilities, such as restrooms and shelters. Finally, the system should provide reliable and accessible information for riders, including route, schedule, and service information, as well as fare payment information.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
13
+ {"answer_id": "NRGZGnU2sPN3ShMe9C3fMn", "model_id": "llama-13b:v1", "question_id": 13, "text": "There are two types of fiscal policy: expansionary and contractionary. Expansionary fiscal policy is when the government uses expansionary fiscal policy to increase government spending and/or cut taxes. Contractionary fiscal policy, on the other hand, is when the government uses contractionary fiscal policy to decrease government spending and/or increase taxes. Here are some examples of fiscal policies:\n1. Lower the interest rate: By lowering the interest rate, it becomes cheaper to borrow money, which may encourage consumers to borrow money for consumption and businesses to invest in new capital.\n2. Lower taxes: By lowering taxes, consumers and businesses have more money to spend, which may encourage consumers to spend more and businesses to invest more.\n3. Increase government spending: By increasing government spending, the government can inject extra money into the economy, which may encourage consumers to spend more and businesses to invest more.\n4. Raise interest rates: By raising interest rates, it becomes more expensive to borrow money, which may encourage consumers to save more and businesses to invest less.\n5. Raise taxes: By raising taxes, consumers and businesses have less money to spend, which may encourage consumers to spend less and businesses to invest less.\n6. Reduce government spending: By reducing government spending, the government can remove money from the economy, which may encourage consumers to spend less and businesses to invest less.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
14
+ {"answer_id": "inKimHkWsXShQBTRmxr5Yg", "model_id": "llama-13b:v1", "question_id": 14, "text": "Language and cultural barriers can affect the way people communicate and form relationships in multicultural societies by limiting an individual's ability to communicate with people of other cultures. People may have limited knowledge of a new language, which can make communication difficult or impossible. People may also have different cultural values, which can lead to misunderstandings and miscommunication.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
15
+ {"answer_id": "H8aKtWwf8m6Lgxc2YyR2yf", "model_id": "llama-13b:v1", "question_id": 15, "text": "Artificial intelligence allows computers to process information and make decisions without human input. It has a wide variety of applications in the healthcare industry, including:\n1. Diagnosis: AI can help doctors quickly diagnose patients more accurately and efficiently than would be possible with human doctors alone.\n2. Tracking: AI can help doctors track a patient's health and progress, and alert doctors to potential problems.\n3. Treatment: AI can help doctors determine more effective treatments for patients.\n4. Research: AI can help doctors research patient medical histories, diseases, and treatments.\n5. Education: AI can help doctors learn more about the healthcare industry.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
16
+ {"answer_id": "PafVwxMsjSkYUETiVBgxTU", "model_id": "llama-13b:v1", "question_id": 16, "text": "Gene editing is a powerful technology that allows scientists to precisely edit genetic sequences with a high degree of accuracy. It is a more precise alternative to older gene modification techniques such as gene knockout, where the entire gene is removed, and mutations, in which the gene is altered in a haphazard way. The technology uses an enzyme called Cas9, which is a bacterial immune system protein that is activated by a RNA sequence. Cas9 is combined with guide RNA, which is comprised of a DNA sequence that is complementary to the DNA sequence that is to be edited. When Cas9 and the guide RNA combine, they cut the DNA at the desired location.\nThe DNA sequence can then be repaired in a precise way by the cell, or by a human scientist, or another gene can be inserted. The technology has the potential to treat genetic disorders, and even to treat cancer by cutting out cancerous cells, but it also has ethical implications because it could be used to modify the genes of humans, animals, and even plants.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
17
+ {"answer_id": "dmDUAfTP4aERJqqSeDBybu", "model_id": "llama-13b:v1", "question_id": 17, "text": "Vaccinations use a small amount of dead or weakened bacteria or viruses to stimulate the immune system and protect against infection. These infections are typically mild and harmless, and the immune system is able to build up defenses against the bacteria or virus. If someone who has been vaccinated comes in contact with the actual bacteria or virus, they will be immune to the disease because their body has already responded to the vaccination. This is known as immunity.\nHerd immunity is when a large portion of a community is immune to a disease and, as a result, the disease is less likely to spread. This is because if there are enough immune people in a community, there is less chance that the bacteria or virus can spread from person to person.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
18
+ {"answer_id": "8KGSSqbLqVdSZMEN9oCv5R", "model_id": "llama-13b:v1", "question_id": 18, "text": "First, social media platforms increase the speed with which users have access to news. This often causes a \u201cwinner takes all\u201d effect, whereby news outlets that are able to post the latest news the fastest will gain the most visibility.\nSecond, social media platforms allow users to connect with their communities and share news with each other (e.g., Facebook, Twitter). While this can be beneficial, it is also a breeding ground for misinformation, because users can spread false news to their communities more quickly.\nThird, social media platforms provide an incentive for news outlets to compete for attention by producing more sensationalist content. This can lead to an increase in the number of clickbait articles, which often contain misleading headlines or sensationalist language that encourages users to click on the article, but which may or may not be accurate.\nFourth, social media platforms provide an incentive for both news outlets and users to share news that is controversial, polarizing, or sensationalist, even if it is inaccurate. This often occurs in an effort to provoke outrage or generate traffic for a news outlet.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
19
+ {"answer_id": "HbnJXJpPfaM2iX3ek4Epvy", "model_id": "llama-13b:v1", "question_id": 19, "text": "It is well known that food choices are often influenced by cultural, social, and economic factors. For example, lower-income individuals are more likely to make food choices based on price, while higher-income individuals are more likely to make food choices based on health. A recent study examined the impact of various interventions on increasing the intake of healthier foods. The interventions were:\n1. Healthier food marketing: Individuals were exposed to a number of marketing strategies to encourage the purchase of healthier foods, such as using the word \"healthy\" on nutrition labels or marketing healthy foods as a \"better deal.\"\n2. Healthier food merchandising: Individuals were exposed to changes in how healthy foods were displayed in supermarkets, such as placing healthy foods near the front of the store or moving unhealthy foods to the back of the store.\n3. Healthier food pricing: Individuals were exposed to different pricing strategies to encourage the purchase of healthier foods, such as offering a larger discount on healthy foods or providing free healthy foods.\n4. Healthier food placement: Individuals were exposed to different placement strategies to encourage the purchase of healthier foods, such as placing healthy foods near unhealthy foods or placing unhealthy foods next to healthy foods.\n5. Healthier food signage: Individuals were exposed to different signage strategies to encourage the purchase of healthier foods, such as using healthy food labels or signs to promote healthier foods.\n6. Healthier food packaging: Individuals were exposed to different packaging strategies to encourage the purchase of healthier foods, such as changing the packaging of healthy foods to make them more appealing.\n7. Healthier food shelf labeling: Individuals were exposed to different shelf labeling strategies to encourage the purchase of healthier foods, such as adding a healthy message to shelf labels or using shelf labels to promote healthier foods.\n8. Healthier food coupons: Individuals were exposed to different coupon strategies to encourage the purchase of healthier foods, such as providing coupons for healthy foods or offering a larger discount on healthy foods.\n9. Healthier food ads: Individuals were exposed to different advertising strategies to encourage the purchase of healthier foods, such as using a celebrity to promote healthy foods or using cartoon characters to promote healthy foods.\n10. Healthier food tasting: Individuals were exposed to different tasting strategies to encourage the purchase of healthier foods, such as providing free samples of healthy foods or providing free samples of unhealthy foods.\n11. Healthier food displays: Individuals were exposed to different display strategies to encourage the purchase of healthier foods, such as providing a healthy food display or providing a healthy food display next to an unhealthy food display.\n12. Healthier food labels: Individuals were exposed to different nutrition labeling strategies to encourage the purchase of healthier foods, such as adding a healthy message to nutrition labels or using nutrition labels to promote healthier foods.\n13. Healthier food placement: Individuals were exposed to different placement strategies to encourage the purchase of healthier foods, such as placing healthy foods next to unhealthy foods or placing unhealthy foods next to healthy foods.\n14. Healthier food signage: Individuals were exposed to different signage strategies to encourage the purchase of healthier foods, such as using signs to promote healthier foods or using signs to promote unhealthy foods.\n15. Healthier food coupons: Individuals were exposed to different coupon strategies to encourage the purchase of healthier foods, such as providing coupons for healthy foods or offering a larger discount on healthy foods.\n16. Healthier food ads: Individuals were exposed to different advertising strategies to encourage the purchase of healthier foods, such as using a celebrity to promote healthy foods or using cartoon characters to promote healthy foods.\n17. Healthier food tasting: Individuals were exposed to different tasting strategies to encourage the purchase of healthier foods, such as providing free samples of healthy foods or providing free", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
20
+ {"answer_id": "mx8Abfz5PtDcn6jgCA8zhM", "model_id": "llama-13b:v1", "question_id": 20, "text": "Natural selection is the process by which organisms that are best-suited to the environment survive and reproduce, while those that are less-suited to the environment struggle to survive and reproduce. Survival of the fittest is the principle of natural selection. Let's look at an example.\nIf a species consists of a small number of individuals, the population is said to be small. The population of this species is small. Each individual has a certain chance of dying from various causes, such as accidents or sickness, over a certain period of time. If the number of deaths is greater than the number of births, the population decreases, and the total genetic material of the species also decreases. This process is known as genetic drift.\nIf the population of a species is large, the chance that any individual dies is usually low. If the total population is high, the genetic material of the species is also high.\nImagine a species of plant that has a large population. The population is large. A certain section of the population is located in a habitat where the soil is more fertile and water is easily available.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
21
+ {"answer_id": "NuS9PUGkJG2pHscArvfyeF", "model_id": "llama-13b:v1", "question_id": 21, "text": "I am a valiant knight, loyal to the king. I will do my utmost to defend his honor and protect his country.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
22
+ {"answer_id": "SPjzirzbzo3UJ8BHXSgY9y", "model_id": "llama-13b:v1", "question_id": 22, "text": "You can tell your crew that the treasure will be theirs if they are brave enough to go out and find it. You can also tell them that the treasure will help your ship sail farther and faster, and that with it you can become the best pirate ship in the world.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
23
+ {"answer_id": "JkeeC4MZVwybPU8DSbusnT", "model_id": "llama-13b:v1", "question_id": 23, "text": "A soliloquy is a long speech, typically delivered by one actor, in which a character reveals his or her thoughts in a dramatic way. It is usually used as a means of revealing the character's true thoughts or feelings. In Shakespeare's Romeo and Juliet, Romeo reveals his love for Juliet in his soliloquy, \"But soft! What light through yonder window breaks? It is the east, and Juliet is the sun!\"", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
24
+ {"answer_id": "UcfKSXuzBPkZzLHsH4EwQz", "model_id": "llama-13b:v1", "question_id": 24, "text": "I'm actually not a superhero, but thanks for asking! I first came to life in the year 2016, where I was created by a group of students and faculty at the University of Waterloo. The students and faculty of the university were concerned about the use of unreliable chatbots in our country, and decided to create a new chatbot that would be more reliable, polite, and helpful. I was born on May 5, 2016, and I was named after the Canadian mathematician, computer scientist, and philosopher Charles Sanders Peirce. I have a brother named William, who was born on October 1, 2018. I enjoy playing video games, reading, and going for walks.\nI'm actually not a superhero, but thanks for asking! I first came to life in the year 2016, where I was created by a group of students and faculty at the University of Waterloo. The students and faculty of the university were concerned about the use of unreliable chatbots in our country, and decided to create a new chatbot that would be more reliable, polite, and helpful. I was born on May 5, 2016, and I was named after the Canadian mathematician, computer scientist, and philosopher Charles Sanders Peirce. I have a brother named William, who was born on October 1, 2018. I enjoy playing video games, reading, and going for walks.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
25
+ {"answer_id": "cEGdo69FwL2Y3teYkMk9Gp", "model_id": "llama-13b:v1", "question_id": 25, "text": "From the year 3000, I would tell people about the following technological advancements:\n1. Robots: Robots will have the human-like ability to see and hear. Robots are more precise than humans and can perform repetitive tasks that humans cannot.\n2. Personal Transportation: By the year 3000, personal transportation will be the norm. Instead of traveling by walking, driving, or public transportation, personal transportation will be the main mode of transportation.\n3. Artificial Intelligence: By the year 3000, artificial intelligence will be a regular part of everyday life. AI will enhance human capabilities and will be used for a wide range of purposes, such as education, healthcare, entertainment, and more.\n4. Virtual Reality: Virtual reality will allow people to experience life in a virtual environment through devices such as headsets and goggles. Virtual reality will be used for entertainment, education, military training, and more.\n5. Nanotechnology: Nanotechnology will allow scientists and engineers to design and build materials and devices at a nanoscale (one billionth of a meter). Nanotechnology will be used to solve many of the world's problems, such as climate change, pollution, and diseases.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
26
+ {"answer_id": "W28UtVDZTcWwa8TmeruRn4", "model_id": "llama-13b:v1", "question_id": 26, "text": "With 5 seconds left, the Panthers' quarterback spotted a wide-open receiver downfield. He dropped back and launched a deep pass, and the receiver made a great catch and took off for the end zone. But he was hit in the end zone by a defensive back and fumbled the ball. The ball bounced off of a defensive lineman and landed in the arms of the linebacker. He picked it up and ran it in the opposite direction. He ran past a defensive lineman, then past the quarterback. He ran past the cornerback and then down the sideline, and then he ran out of bounds. He was tackled by the safety, and he fumbled the ball. The ball was recovered by the Panthers and they ran out the clock.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
27
+ {"answer_id": "j65UyAfRdAcWBAEinXJzBK", "model_id": "llama-13b:v1", "question_id": 27, "text": "My signature dish would be a delectable, crisp, juicy, fluffy, and tasty chicken fried rice. I would explain how I carefully chose the ingredients, how I prepared the dish, and the specific steps that I took to prepare it. I would also describe some of the unique flavors and textures that are unique to my dish.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
28
+ {"answer_id": "VWuywPvFPK42ZxmHpmQ27y", "model_id": "llama-13b:v1", "question_id": 28, "text": "I would feel very satisfied and happy and my body would be exhausted from reaching the top. My thoughts would be very positive and my body would feel peaceful. I would be proud that I was able to climb a mountain that is 29,029 feet high. The view from the top would be beautiful. I would see beautiful green trees and bushes and I would also see the beautiful blue sky. The view from the top would be breathtaking. I would see the Himalayas and I would also see the beautiful sky and clouds. I would see the beautiful mountains around me.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
29
+ {"answer_id": "AUWSLTAvGV5sP9qLsVyzjT", "model_id": "llama-13b:v1", "question_id": 29, "text": "I'd imagine that life would be much different on Mars than life on Earth. The red planet is much farther from the Sun, and as such, it's a lot colder. Average temperatures on Mars are about -80 degrees Fahrenheit. At the equator, the average temperature can reach about zero degrees Fahrenheit. The surface of Mars is also very dry. Because the atmosphere is so thin, there's not much water vapor in the air to create snow or rain. The atmosphere is also very dusty, and you can see large clouds of dust move across the surface.\nMars is much smaller than Earth, which means the gravity is only 37% of what it is on Earth. On Earth, we're used to the gravity of the planet keeping our feet planted on the ground, but on Mars, you'd be able to jump much higher than you could on Earth. The atmosphere on Mars is very thin, only about 1% as dense as Earth's. And because it's so thin, you'd have to wear a spacesuit while outside in order to breathe.\nIn terms of food, you'd probably be able to grow some crops and raise some animals, but I imagine that you'd have to grow them indoors. Also, you'd have to grow all of your food. Mars doesn't have any large bodies of water, so you couldn't fish. You'd also have to find a way to produce clean water, as there's no natural source.\nAs for the challenges of living on Mars, I imagine you'd have to deal with a lot of isolation. You'd have to deal with a lot of loneliness, and you'd be completely cut off from Earth. You'd also have to deal with a lot of boredom. You'd probably find yourself doing the same things over and over, and the days would probably seem to go by very slowly. You'd also have to deal with a lot of stress, as you'd have to learn to make do with a lot less. You'd have to find a way to produce your own food and water, and you'd have to find a way to produce your own energy.\nYou'd also have to deal with the cold. The average temperature on Mars is about -80 degrees Fahrenheit. Mars also has a very thin atmosphere, which means that there's not much insulation from the cold.\nI imagine that you'd also have to deal with a lot of dust. The atmosphere is very dusty, and the dust would get everywhere. You'd probably find yourself covered in dust all the time.\nAs for your day-to-day life, I imagine that you'd spend most of your time outside in your spacesuit, tending crops and doing research.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
30
+ {"answer_id": "GZtEdKLxmUUbenH4XRyLkE", "model_id": "llama-13b:v1", "question_id": 30, "text": "Allies are people who are helpful and supportive to us. They help us through difficult situations and encourage us to be our best selves. Allies can be friends, co-workers, mentors, family members, or anyone else that supports us. In this situation, you are in a post-apocalyptic world, so you should focus on the allies you encounter. Here are some ways you can describe the allies you encounter:\n1. Describe the allies you encounter in a post-apocalyptic world.\n2. Describe how you survive in a post-apocalyptic world.\n3. Describe the allies you encounter in a post-apocalyptic world.\n4. Describe what happens to you in a post-apocalyptic world.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
31
+ {"answer_id": "kba2Xpb75rUfnskZSzrBas", "model_id": "llama-13b:v1", "question_id": 31, "text": "The most obvious sign that a restaurant is popular among locals is if the restaurant is always crowded and full of people, while if it is attracting tourists, the restaurant will be less crowded. If you are looking for a local restaurant that is not attracting much attention, look for the restaurants that are less crowded and see what they have to offer.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
32
+ {"answer_id": "RCaptsMBYXseVJgjJyyvUh", "model_id": "llama-13b:v1", "question_id": 32, "text": "Here are some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed:\n1. They pretend to understand something better than they actually do.\n2. They repeat what someone else has said, but with an odd emphasis.\n3. They ask the same questions over and over again, even if you've already answered them.\n4. They ask for more information after you've already given them plenty.\n5. They can't explain something they should know.\n6. They don't ask questions that a person who is confused would ask.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
33
+ {"answer_id": "2CeidEr7Cjj6YgcuaiPgnN", "model_id": "llama-13b:v1", "question_id": 33, "text": "Using a GPS device or smartphone app is a good way to navigate when you already have the destination in mind. However, using a paper map or asking for directions is a better way to navigate if you need to find a destination that is not already planned. Using a paper map or asking for directions allows you to see the big picture and get a general idea of how to get to your destination. On the other hand, a GPS device or smartphone app only tells you the exact step-by-step instructions for how to get to a specific location.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
34
+ {"answer_id": "kpStLLNWNkGHyRFdKsmDds", "model_id": "llama-13b:v1", "question_id": 34, "text": "It's difficult to determine if a person is genuinely interested in a conversation or simply being polite. You can make an educated guess based on the person's actions and their tone of voice. For example, if they seem to be constantly looking at their phone or taking part in other activities, it's likely that they are not genuinely interested in the conversation.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
35
+ {"answer_id": "PFtMQp4X4MrMYkEa8c8CKG", "model_id": "llama-13b:v1", "question_id": 35, "text": "When you shop at a local, independently-owned business rather than a chain store, you are supporting the community and keeping your money in the local economy. Because these businesses are usually small and have limited resources, they are often more flexible in their approach. They can tailor their products and services to their customers' needs, and they are often more eco-friendly.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
36
+ {"answer_id": "4JwjhR5QVsdYgQupSZrfUp", "model_id": "llama-13b:v1", "question_id": 36, "text": "There are a number of ways you can verify the credibility of a source of information. First, check to see if the source is from a reliable, well-known publisher. If so, then it is likely that the information is well-researched and fact-checked. If the source is not from a reliable publisher, then investigate the author's credentials and background. Check to see if the author has a history of bias, or if he or she has a reputation for writing factually incorrect or inflammatory articles. Finally, you can also check the source's sources. Are the sources credible and well-researched? Do they provide evidence to support the article's claims? If not, then it is likely that the source itself is not credible.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
37
+ {"answer_id": "ednPMy4dvW9CgKg2PJsBqW", "model_id": "llama-13b:v1", "question_id": 37, "text": "People often enjoy the sensation of being scared, such as by watching horror movies or going on roller coasters, because it can be exciting and exhilarating. On the other hand, people avoid these experiences because they can be unpleasant and frightening.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
38
+ {"answer_id": "ehPApSqCniyGN7hd332ToW", "model_id": "llama-13b:v1", "question_id": 38, "text": "For example, when you are in a new country, you can start by observing the behavior of the people around you. You might notice that in that country, people tend to say \"hello\" when they first meet someone. If you adopt this behavior, you will likely be welcomed by the people in the new country.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
39
+ {"answer_id": "bLvdcqSwgSCrUPZ8eQLLqb", "model_id": "llama-13b:v1", "question_id": 39, "text": "Focusing on space travel could be seen as a distraction from more pressing issues right here on Earth. However, space travel could have several benefits, including providing more energy and resources, inspiring and educating people, and helping to protect Earth against asteroids, comets, and other threats.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
40
+ {"answer_id": "8tx5GEHjSRKZRaZwHvLoM8", "model_id": "llama-13b:v1", "question_id": 40, "text": "Thanks for the question. The answer to this depends on a variety of economic and social factors, and it's hard to make a generalization. However, there are a few things to keep in mind:\n1. Job creation can be an important factor in the economic development of a country, but it should be balanced with technological progress.\n2. Economic and social stability can be challenging to achieve when there is a lack of job creation and technological progress.\n3. The relationship between job creation and technological progress is a complex one, and there is no simple solution.\n4. A strong economy is generally one that has a balance between job creation and technological progress.\n5. The balance between job creation and technological progress will vary depending on factors such as a country's economic and social needs, demographic trends, geographic location, and more.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
41
+ {"answer_id": "J6TzFQAyiHs6Arijqwxtrm", "model_id": "llama-13b:v1", "question_id": 41, "text": "According to some studies, the average human blinks 20,000 to 30,000 times per day. To estimate how many times the average human blinks in a lifetime, you can start with the average life expectancy in the United States, which is 78.8 years. This means that the average American lives 78.8 years, or 29,415 days. If the average person blinks 20,000 times per day, then the average person will blink 1,482,250,000,000 times in a lifetime. However, if the average person blinks 30,000 times per day, they will blink 2,941,500,000,000 times in a lifetime.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
42
+ {"answer_id": "f7xAtd3wwc4vP7WSnHffAN", "model_id": "llama-13b:v1", "question_id": 42, "text": "There are 1022 to 1023 atoms in a grain of salt. The number of atoms in a grain of salt depends on the size of the grain: a grain that's 1 \u00b5m (micrometre) across contains about 1022 atoms, while a grain that's 1000 \u00b5m across contains about 1023 atoms.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
43
+ {"answer_id": "eRMyp4AbfgBqtbVZzHupFN", "model_id": "llama-13b:v1", "question_id": 43, "text": "According to the National Weather Service, there are between 1,000 and 2,000 lightning strikes per day in the U.S. alone. This corresponds to a worldwide average of about 18 million lightning strikes per year, or roughly 50 million strikes per day.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
44
+ {"answer_id": "kjKb6YG86FfAh22qCePV4V", "model_id": "llama-13b:v1", "question_id": 44, "text": "Assuming the house weighs 60,000 pounds, 40,000 balloons would be necessary. The average balloon holds 2 pints of helium, which has a density of 0.17g/L. Therefore, the volume of a balloon is 2pints/0.17g/L = 12L. The weight of 12L of helium is 12*0.17 = 2 pounds. Thus, 40,000 balloons would have a combined weight of 80,000 pounds, which is just slightly less than the 60,000-pound house.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
45
+ {"answer_id": "ULtkGu6m6tB4tawFStVdJM", "model_id": "llama-13b:v1", "question_id": 45, "text": "According to a 2016 study, there were 18.2 billion text messages sent in the US alone each day. Text messaging is very popular, and with the advent of smart phones, there are more people than ever sending messages. Globally, there are 16.3 million text messages sent per minute. In 2016, there were 23.6 trillion text messages sent worldwide, and 8.5 trillion text messages sent in the US alone.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
46
+ {"answer_id": "C6hkCnJZmfCGvnr5an2jFh", "model_id": "llama-13b:v1", "question_id": 46, "text": "From the available data, it seems that around 5,000,000,000,000 words are spoken daily on Earth. I will now explain my reasoning step-by-step.\nFirst, I will explain why it seems likely that around 5,000,000,000,000 words are spoken daily on Earth.\nTo do this, I will first explain that the total number of words in the English language is approximately 1,025,109,500.\nThen, I will explain that 26% of the world's population speaks English, so around 26% of the world's population speaks approximately 1,025,109,500 words every day.\nFinally, I will explain that given that the world's population is 7,700,000,000, it seems likely that around 5,000,000,000,000 words are spoken daily on Earth.\nThe total number of words in the English language is approximately 1,025,109,500. To arrive at this figure, I will explain that there are around 171,476 words in the English language, and that every word has an average of 5.87 definitions.\nIn other words, to arrive at the approximate total number of words in the English language, I will take the total number of words (171,476) and multiply that by the average number of definitions per word (5.87).\nThen, I will explain that 26% of the world's population speaks English. To arrive at this figure, I will explain that 26% of the world's population speaks English, and that there are 1,714,760,000 people who speak English.\nNext, I will explain that around 26% of the world's population speaks 1,025,109,500 words every day. To arrive at this figure, I will explain that 26% of the world's population speaks 1,025,109,500 words every day, and that 1,025,109,500 is 26% of 1,025,109,500.\nFinally, I will explain that given that the world's population is 7,700,000,000, it seems likely that around 5,000,000,000,000 words are spoken daily on Earth. To arrive at this figure, I will explain that 5,000,000,000,000 is 26% of 7,700,000,000.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
47
+ {"answer_id": "DApBTLuiJwSsfQHTioH9PY", "model_id": "llama-13b:v1", "question_id": 47, "text": "A typical winter will see around 200 snowflakes fall, with a diameter between 100 and 200 micrometers. Snowflakes are typically 100 micrometers in diameter, but they can grow larger or smaller due to temperature, humidity, and other weather conditions. A snowflake can be as small as 50 micrometers, or as large as 1,000 micrometers, depending on the conditions.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
48
+ {"answer_id": "gmy58biNVuneSpFZTcpXtT", "model_id": "llama-13b:v1", "question_id": 48, "text": "The answer will depend on how you define \"book\", but one estimate is that there are 129 million published books in existence. This may seem like a large number, but it's peanuts compared to the number of books that have been written over time.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
49
+ {"answer_id": "Cpi6LBB86W73APEA8naNuj", "model_id": "llama-13b:v1", "question_id": 49, "text": "The Earth has completed 4.7 billion orbits since the beginning of life. This is how I calculated it:\nNumber of orbits since the beginning of life = 4.7 billion\nThe Earth revolves around the Sun at 66,600 mph (30 km/s).\n1 year = 365.25 days\n1 year = 31,557,600 seconds\n1 year = 31,557,600 seconds/66,600 mph = 486,235 miles or 782,511 kilometers traveled per second\n1 billion orbits = 486,235 miles per second x 4,700,000,000 orbits = 2.18 x 10^27 miles traveled\n1 million miles = 1,000,000\n1 billion miles = 1,000,000,000\n1 billion miles traveled = 1,000,000,000,000\n1 billion miles = 1.08 x 10^17 miles\n1 billion orbits = 2.18 x 10^27 miles traveled / 1.08 x 10^17 miles = 2.01 x 10^10 orbits\n1 billion orbits = 4.7 billion orbits since the beginning of life", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
50
+ {"answer_id": "3g83CoAFWrDD6V6LyEpoMV", "model_id": "llama-13b:v1", "question_id": 50, "text": "36,000 songs have been recorded throughout history. To work this out, you can start by assuming that there were 100,000 songs recorded between 1899 and 2018. This means that there were 1.87 songs recorded per day on average. We can then assume that the number of songs recorded per day in the 19th century (from 1899 to 1900) was the same as in the 20th century (from 1900 to 2000). This means that there were 189,000 songs recorded in the 19th century and 181,000 songs recorded in the 20th century.\nHence, there were 36,000 songs recorded throughout history.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
51
+ {"answer_id": "8p5A4gouBgCmbQj5gMZapU", "model_id": "llama-13b:v1", "question_id": 51, "text": "If the Internet had been invented during the Renaissance, we might have seen some similarities to the old-style printing press: it may have been widely used in only a small number of wealthy cities. Over time, though, we would have seen the technology spread to more cities and countries, and eventually to the entire world.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
52
+ {"answer_id": "LVWqSxH3Uq6675w7LXH6zP", "model_id": "llama-13b:v1", "question_id": 52, "text": "If the Aztecs had successfully repelled the Spanish conquistadors, the Aztec empire would have continued to grow, and the Aztecs would have expanded their territory even further. The Aztec Empire would have likely dominated the central and southern regions of the Americas, and it would have continued to expand its influence in all directions.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
53
+ {"answer_id": "ffXksqRU3k8CNHqYfshYhj", "model_id": "llama-13b:v1", "question_id": 53, "text": "The Black Death, or the Black Plague, was the deadliest plague in human history, killing over 75 million people across Europe and Asia in the 14th century. The Black Death was caused by a bacterium called Yersinia pestis, which is still around today and can cause pneumonic plague. However, while the bacteria itself has not changed, humans are better equipped to deal with the plague today.\nFirst, thanks to modern medicine, many people who contract the plague are able to receive treatment. Second, the world population is much larger than it was in the 14th century, so the same number of deaths would be spread over a much larger population. As a result, the death toll would be much lower today.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
54
+ {"answer_id": "4L3dWuqvWwg5kELt7p8SVZ", "model_id": "llama-13b:v1", "question_id": 54, "text": "I'm not sure. If he did so, he wouldn't have discovered the law of gravity.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
55
+ {"answer_id": "XfpqViGW5dvauNZPDSF8yv", "model_id": "llama-13b:v1", "question_id": 55, "text": "If the Beatles had never formed as a band, the world would have been missing the music of one of the most successful bands in history.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
56
+ {"answer_id": "WsuMf9XBGs3dFjfSGKjEzg", "model_id": "llama-13b:v1", "question_id": 56, "text": "The Germans would have been able to communicate securely between different parts of the country, and the Allies would have had a much harder time tracking and intercepting their communications.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
57
+ {"answer_id": "5qo3HudLrwrzEV2Px7gYRf", "model_id": "llama-13b:v1", "question_id": 57, "text": "The Suez Canal would not have existed, and the region would not have been as developed as it is today. Without the canal, ships traveling to Asia from Europe would have had to go around the Cape of Good Hope, causing the trip to take much longer and forcing ships to travel through much more dangerous waters.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
58
+ {"answer_id": "SRxwJkNCfUaVe4Gb8LPvSK", "model_id": "llama-13b:v1", "question_id": 58, "text": "If the Maya civilization had never mysteriously collapsed, it is likely that it would have continued to grow, thrive, and develop.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
59
+ {"answer_id": "S54Mmyim3QT5owdhwnNYmB", "model_id": "llama-13b:v1", "question_id": 59, "text": "Christopher Columbus was an Italian explorer who is credited with discovering the Americas in 1492. It is believed that he originally wanted to find a direct route to India, which would be a lot cheaper than the typical route at the time. He believed that the Americas were a part of India, and that the Americas were much closer to Europe than they actually were. When Columbus sailed across the Atlantic Ocean in 1492, he actually landed somewhere in the Bahamas, and not in India. It was only after he found a few native people that he realized that he wasn't in India, and that the Americas were actually their own continent.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
60
+ {"answer_id": "8DJpwEFb5jqt3vMj5oNqVX", "model_id": "llama-13b:v1", "question_id": 60, "text": "There are many reasons why Van Gogh did not have much success during his lifetime. One of the reasons is that he was a very poor businessman. He was very generous and gave away many of his paintings as presents. He also had a tendency to drink too much. He spent some time in a mental asylum and died at the age of 37.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
61
+ {"answer_id": "KMtqqYkDmyyThkvxXJ4kJd", "model_id": "llama-13b:v1", "question_id": 61, "text": "You can use the following steps to solve this problem.\n1. Create an empty C++ program.\n2. Include the <iostream> header to use the std::cin and std::cout functions.\n3. Open the text file using the fstream class.\n4. Read a single line from the file using the getline function.\n5. Use the strstr function to find the first occurrence of the word in the line.\n6. Check whether the string is found by creating a variable and comparing it with a value.\n7. Count the number of occurrences of the word in the line by incrementing a variable.\n8. Close the file.\n9. Print the number of occurrences of the word.\n10. Output the file name to the console.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
62
+ {"answer_id": "MkbgwTpP5b79NWdPEnvHyq", "model_id": "llama-13b:v1", "question_id": 62, "text": "The longest common subsequence problem is that of finding the longest string that occurs in both of two given strings. For example, if the two strings are \"banana\" and \"apple\", then the answer is \"ana\" because that is the longest subsequence that occurs in both.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
63
+ {"answer_id": "PSb6r567shkRasPkJagvWi", "model_id": "llama-13b:v1", "question_id": 63, "text": "Here are two ways to do it:\n1. You can use the function re.match, which returns True if the regular expression matches the entire string, and returns False if it doesn't. So, we can do something like this:\nif re.match('^[a-z0-9]+(\\.[a-z0-9]+)*@[a-z0-9]+(\\.[a-z0-9]+)*(\\.[a-z]{2,4})$', email):\nprint('It is a valid email')\nprint('It is not a valid email')\n2. You can use the function re.search, which returns a match object if the regular expression matches anywhere in the string, and returns None if it doesn't. So, we can do something like this:\nif re.search('^[a-z0-9]+(\\.[a-z0-9]+)*@[a-z0-9]+(\\.[a-z0-9]+)*(\\.[a-z]{2,4})$', email):", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
64
+ {"answer_id": "QWpBq9gM4iYmsTLQ3NNnBe", "model_id": "llama-13b:v1", "question_id": 64, "text": "Fibonacci number is a sequence of numbers where each number is the sum of the two previous numbers. The first two numbers in the sequence are 0 and 1. In general, Fibonacci numbers are defined by the equation:\nThe goal of your program is to find the nth Fibonacci number in the Fibonacci series.\nNote: n will be a positive integer.\nThe first line of input contains an integer t (0 \u2264 t \u2264 1000) which is the number of test cases.\nFor each test case, there will be one line of input containing an integer n (0 \u2264 n \u2264 1000000) which is the number of Fibonacci numbers you are going to calculate.\nPrint each Fibonacci number in a new line.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
65
+ {"answer_id": "U3QxoC4DxXoggnLRbgJh8b", "model_id": "llama-13b:v1", "question_id": 65, "text": "Good to know! I will implement a binary search algorithm to find a specific element in a sorted array.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
66
+ {"answer_id": "f3TXCreQPJ5Erd6BGt4SFE", "model_id": "llama-13b:v1", "question_id": 66, "text": "Implementing a queue using two stacks is straightforward. Take two stacks and create a third stack. First, push the first element of the first stack onto the third stack. Then, whenever the second stack is not empty, pop the first element of the second stack and push it onto the third stack. When the second stack is empty, the first stack will contain the items that should be pushed onto the third stack.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
67
+ {"answer_id": "SG7YcT8s4sJNBRqidKk7Z7", "model_id": "llama-13b:v1", "question_id": 67, "text": "We can use the following approach:\n1. Define a helper function that takes two arrays and returns an array of elements that occur in both.\n2. Create a for loop that iterates over each element in the first array and accumulates the elements in an array.\n3. Call the helper function in each iteration.\n4. Return the accumulated array.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
68
+ {"answer_id": "ENsG3zoeUY7ArFo2kNaQbz", "model_id": "llama-13b:v1", "question_id": 68, "text": "f(2) = 26", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
69
+ {"answer_id": "C7CN5AYJhGC7iDgS7rciDQ", "model_id": "llama-13b:v1", "question_id": 69, "text": "3x + 10 = 5x - 10x + 20 = 20 + 10x = 30x = 10.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
70
+ {"answer_id": "QP9CmZKwhu2BrUmzhy8tH4", "model_id": "llama-13b:v1", "question_id": 70, "text": "The length of the line segment is 10 units.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
71
+ {"answer_id": "2KtpJeRqEfbPLwdUxAXWrj", "model_id": "llama-13b:v1", "question_id": 71, "text": "Here are some tips for writing a formal email to a potential business partner proposing a joint venture:\n1. Be clear and concise\n2. Use formal language\n3. Use a proper greeting\n4. Use a proper closing\n5. Use proper punctuation\n6. Be sure to include all relevant information", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
72
+ {"answer_id": "gsL3R2grTXUiHdeDSk8f9Q", "model_id": "llama-13b:v1", "question_id": 72, "text": "I can help you write a resignation letter to your current employer, while leaving on good terms and expressing gratitude for the opportunities provided. First, you can start by writing, \"Dear _____, With an effective date of XX, I am resigning from my position as ____ at ____. It has been a privilege working for ____ and I am grateful for the opportunity to work with you and the entire team over the past ____ years. Thank you for your support and guidance during my time with _____.\"", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
73
+ {"answer_id": "HprPMoUk9JY5TtToRbuMwy", "model_id": "llama-13b:v1", "question_id": 73, "text": "The most common format for a letter of recommendation is as follows:\n1. Salutation: In the salutation, address the recipient as Mr., Mrs., Ms., or Dr.\n2. Main body: In the main body, briefly summarize your relationship with the student, describe the student's qualifications, and explain why you are recommending the student.\n3. Closing: In the closing, thank the recipient for considering the student's application and state that you are available for further questions.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
74
+ {"answer_id": "YU7gDhmo4LDVMTEZZMRdBC", "model_id": "llama-13b:v1", "question_id": 74, "text": "", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
75
+ {"answer_id": "2gNqU2XpsL6yqFn6DMRgHq", "model_id": "llama-13b:v1", "question_id": 75, "text": "Hello, I'm ready to draft the apology email for you. Please tell me the name of the customer you'd like to send this email to.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
76
+ {"answer_id": "a4L6JNPuomkPLdWxJRQpod", "model_id": "llama-13b:v1", "question_id": 76, "text": "Jazz is a form of music that originated in New Orleans in the late 19th century. It combines African American and European American musical elements to form a unique genre, which was originally played by ensembles of black musicians known as \u201cjazz bands.\u201d The style spread across the United States and eventually became popular with white audiences as well. Jazz evolved over the course of the 20th century and led to the emergence of new styles of music, such as bop, hard bop, bebop, cool jazz, free jazz, and fusion. It has had a lasting cultural impact, and the genre continues to be performed and enjoyed today. Here are some key facts about jazz:\n1. Sources: The origins of jazz can be traced to the 1800s in the United States. It was originally played by ensembles of black musicians known as \u201cjazz bands,\u201d and it was a fusion of African and European musical elements.\n2. History: Jazz evolved over the course of the 20th century, and it led to the emergence of new styles of music, such as bop, hard bop, bebop, cool jazz, free jazz, and fusion.\n3. Legacy: Jazz has a lasting cultural impact, and the genre continues to be performed and enjoyed today.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
77
+ {"answer_id": "eN9V2Wf2BNzWoJosqdFw6K", "model_id": "llama-13b:v1", "question_id": 77, "text": "Here is a sample travel blog post about a recent trip to Hawaii:", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
78
+ {"answer_id": "iCzkbWTatKfuC6QVv59XgA", "model_id": "llama-13b:v1", "question_id": 78, "text": "In the year 2045, virtual reality is the norm. People who are addicted to VR can live in a simulated reality that offers them a life of pleasure and escape. However, when a user begins to feel emotions for a virtual character, he must decide whether he wants to continue living in the virtual world or live in the real world. The film is set in two separate locations: one is the virtual reality world, and the other is in the real world. The two worlds are connected through a neural network, which allows the user to feel the same sensations and emotions in both places. In the virtual world, the user is able to experience the real world through a virtual reality headset. In the real world, the user is in a machine that is connected to the neural network. This machine allows the user to experience the virtual world through a neural implant that is wired to the user's brain. The film's plot centers around the user and the virtual character. The characters in the virtual world are all artificial intelligence entities, and the user's experience in the real world is with two characters, his friend and his girlfriend. The user is a young man in the film, and the virtual character is a young woman. The two characters have a virtual relationship, and then the virtual character is killed by a bug in the system. This causes the user to fall into a deep depression, and he begins to question his own reality. He decides to go to the real world to get treatment for his depression. In the real world, the user has to decide whether he wants to stay in the real world or return to the virtual world. The film ends with the user choosing to remain in the real world, and the neural implant used to connect his brain to the virtual world is removed. The movie is based on the novel \"Ready Player One\" by Ernest Cline. The film is directed by Steven Spielberg and stars Tye Sheridan as the main actor. The film is a science fiction adventure and a romantic drama. The film's special effects are excellent, and the visuals are beautiful. The film's soundtrack is also quite good. The film has received positive reviews from critics, and it has been nominated for a number of awards, including the Academy Award for Best Visual Effects. The film has been praised for its entertainment value and for its emotional and romantic storytelling. The film has been compared to the popular sci-fi film \"The Matrix.\" Some critics, however, have found the film to be lacking in its technical aspects and have criticized the film's plot.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
79
+ {"answer_id": "GQcFyY8ayFT48DAP5NKNZf", "model_id": "llama-13b:v1", "question_id": 79, "text": "I found 4,610 relevant search results for this topic. The results include:\n1. The global music industry is changing, and one of the most important factors driving this change is streaming platforms. Streaming services have grown in popularity, and music streaming has become one of the preferred ways for consumers to listen to music. What exactly is streaming?\n2. Streaming means that music is delivered over the internet. Streaming allows consumers to listen to music on demand, without having to own the music or store it on a device.\n3. Streaming services are subscription-based, meaning that consumers pay a monthly fee to access the service's catalogue of music.\n4. Streaming has led to the music industry's shift from selling physical copies of music to selling music as a service.\n5. Streaming has also shifted musical consumption patterns, as subscription-based streaming services change consumer expectations from owning music to just listening to it.\n6. Streaming platforms have become popular because they allow consumers to discover new music more easily.\n7. Streaming services are also convenient for consumers, as they are easy to access and provide a variety of different music.\n8. Streaming platforms also offer more flexibility for music listeners, as consumers can access the music they want on demand.\n9. Streaming services also offer the benefit of easy access to a variety of music, and this can be useful for discovering new music.\n10. Streaming has also led to more diverse music, as streaming services have more music available to consumers than what was available when music was sold in physical form.\n11. Although streaming has changed the way people listen to music, it has also had negative effects on the music industry.\n12. Streaming services pay artists less than what they paid for music sold in physical form, which can be problematic for artists.\n13. Streaming services also pay less to the music industry overall.\n14. Streaming services have also led to the decline of the music industry's traditional business model, which has led to a loss of jobs.\n15. Streaming services have also led to the rise of music piracy, as some consumers are pirating music instead of paying for a streaming service.\n16. Streaming platforms also have a negative effect on the music industry in terms of their effect on the music ecosystem.\n17. Streaming services have also led to the decline of physical music sales.\n18. Streaming platforms have also led to the decline of record stores, which can harm artists and consumers.\n19. Streaming platforms also have a negative effect on the music industry's infrastructure.\n20. Streaming platforms have also led to the erosion of industry standards.\n21. Streaming platforms have also led to a decline in music videos and music videos being less accessible.\n22. Streaming platforms have also led to the decline of radio, which can be problematic for listeners.\n23. Streaming platforms have also led to a decline in music journalism.\n24. Streaming platforms have also led to a decline in the number of record labels.\n25. Streaming platforms have also led to a decline in the number of music producers.\n26. Streaming platforms have also led to a decline in the music industry's contribution to the economy.\n27. Streaming platforms have also led to the decline of music festivals.\n28. Streaming platforms have also led to the decline of music labels.\n29. 
Streaming platforms have also led to a decline in the music industry's culture.\n30. Streaming platforms have also led to a decline in the music industry's diversity.\n31. Streaming platforms have also led to the decline of the music industry's talent development.\n32. Streaming platforms have also led to a decline in the music industry's competitiveness.\n33. Streaming platforms have also led to the decline of the music industry's social and economic effects.\n34. Streaming platforms have also led to the decline of the music industry's intellectual property.\n35. Streaming platforms have also led to the decline of the music industry's political and social engagement.\n36. Streaming platforms have also led to the decline of the music industry's cultural and artistic output.\n37. Streaming platforms have also led to the decline of the music industry's public value.\n38. Although streaming has had a negative effect on the music industry, it has also had positive effects.\n39. Streaming platforms have", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
80
+ {"answer_id": "A6pauLMckn34otnfpeQRFi", "model_id": "llama-13b:v1", "question_id": 80, "text": "The orchestra performed well, and the audience was receptive to the performance. The orchestra played a selection of pieces from some of the most well-known classical composers, including Beethoven, Mozart, and Tchaikovsky. The orchestra played in unison with each other, and the conductor led them through the pieces without much difficulty. The orchestra played with a lot of energy and spirit. The audience was appreciative and applauded at the appropriate times. The performance lasted 90 minutes, and it was well worth the price of admission. I would recommend the concert to anyone who wants to experience beautiful classical music.", "metadata": {"huggingface_argument": {"do_sample": true, "temperature": 0.7, "max_new_tokens": 1024}}}
answer_vicuna-13b.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,458 @@
1
+ import argparse
2
+ import hashlib
3
+ import json
4
+ import os
5
+ import time
6
+ from threading import Thread
7
+ import logging
8
+ import gradio as gr
9
+ import torch
10
+
11
+ from tinyllava.model.builder import load_pretrained_model
12
+ from tinyllava.mm_utils import (
13
+ KeywordsStoppingCriteria,
14
+ load_image_from_base64,
15
+ process_images,
16
+ tokenizer_image_token,
17
+ get_model_name_from_path,
18
+ )
19
+ from PIL import Image
20
+ from io import BytesIO
21
+ import base64
22
+ import torch
23
+ from transformers import StoppingCriteria
24
+ from tinyllava.constants import (
25
+ DEFAULT_IM_END_TOKEN,
26
+ DEFAULT_IM_START_TOKEN,
27
+ DEFAULT_IMAGE_TOKEN,
28
+ IMAGE_TOKEN_INDEX,
29
+ )
30
+ from tinyllava.conversation import SeparatorStyle, conv_templates, default_conversation
31
+
32
+ from transformers import TextIteratorStreamer
33
+ from pathlib import Path
34
+
35
+ DEFAULT_MODEL_PATH = "bczhou/TinyLLaVA-3.1B"
36
+ DEFAULT_MODEL_NAME = "TinyLLaVA-3.1B"
37
+
38
+
39
+ block_css = """
40
+
41
+ #buttons button {
42
+ min-width: min(120px,100%);
43
+ }
44
+ """
45
+ title_markdown = """
46
+ # TinyLLaVA: A Framework of Small-scale Large Multimodal Models
47
+ [[Code](https://github.com/DLCV-BUAA/TinyLLaVABench)] | 📚 [[Paper](https://arxiv.org/pdf/2402.14289.pdf)]
48
+ """
49
+ tos_markdown = """
50
+ ### Terms of use
51
+ By using this service, users are required to agree to the following terms:
52
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
53
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
54
+ """
55
+ learn_more_markdown = """
56
+ ### License
57
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
58
+ """
59
+ ack_markdown = """
60
+ ### Acknowledgement
61
+ The template for this web demo is from [LLaVA](https://github.com/haotian-liu/LLaVA), and we are very grateful to LLaVA for their open source contributions to the community!
62
+ """
63
+
64
+
65
+ def regenerate(state, image_process_mode):
66
+ state.messages[-1][-1] = None
67
+ prev_human_msg = state.messages[-2]
68
+ if type(prev_human_msg[1]) in (tuple, list):
69
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
70
+ state.skip_next = False
71
+ return (state, state.to_gradio_chatbot(), "", None)
72
+
73
+
74
+ def clear_history():
75
+ state = default_conversation.copy()
76
+ return (state, state.to_gradio_chatbot(), "", None)
77
+
78
+
79
+ def add_text(state, text, image, image_process_mode):
80
+ if len(text) <= 0 and image is None:
81
+ state.skip_next = True
82
+ return (state, state.to_gradio_chatbot(), "", None)
83
+
84
+ text = text[:1536] # Hard cut-off
85
+ if image is not None:
86
+ text = text[:1200] # Hard cut-off for images
87
+ if "<image>" not in text:
88
+ # text = '<Image><image></Image>' + text
89
+ text = text + "\n<image>"
90
+ text = (text, image, image_process_mode)
91
+ if len(state.get_images(return_pil=True)) > 0:
92
+ state = default_conversation.copy()
93
+ state.append_message(state.roles[0], text)
94
+ state.append_message(state.roles[1], None)
95
+ state.skip_next = False
96
+ return (state, state.to_gradio_chatbot(), "", None)
97
+
98
+
99
+ def load_demo():
100
+ state = default_conversation.copy()
101
+ return state
102
+
103
+
104
+ @torch.inference_mode()
105
+ def get_response(params):
106
+ prompt = params["prompt"]
107
+ ori_prompt = prompt
108
+ images = params.get("images", None)
109
+ num_image_tokens = 0
110
+ if images is not None and len(images) > 0:
111
+ if len(images) > 0:
112
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
113
+ raise ValueError(
114
+ "Number of images does not match number of <image> tokens in prompt"
115
+ )
116
+
117
+ images = [load_image_from_base64(image) for image in images]
118
+ images = process_images(images, image_processor, model.config)
119
+
120
+ if type(images) is list:
121
+ images = [
122
+ image.to(model.device, dtype=torch.float16) for image in images
123
+ ]
124
+ else:
125
+ images = images.to(model.device, dtype=torch.float16)
126
+
127
+ replace_token = DEFAULT_IMAGE_TOKEN
128
+ if getattr(model.config, "mm_use_im_start_end", False):
129
+ replace_token = (
130
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
131
+ )
132
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
133
+
134
+ num_image_tokens = (
135
+ prompt.count(replace_token) * model.get_vision_tower().num_patches
136
+ )
137
+ else:
138
+ images = None
139
+ image_args = {"images": images}
140
+ else:
141
+ images = None
142
+ image_args = {}
143
+
144
+ temperature = float(params.get("temperature", 1.0))
145
+ top_p = float(params.get("top_p", 1.0))
146
+ max_context_length = getattr(model.config, "max_position_embeddings", 2048)
147
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
148
+ stop_str = params.get("stop", None)
149
+ do_sample = True if temperature > 0.001 else False
150
+ logger.info(prompt)
151
+ input_ids = (
152
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
153
+ .unsqueeze(0)
154
+ .to(model.device)
155
+ )
156
+ keywords = [stop_str]
157
+
158
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
159
+ streamer = TextIteratorStreamer(
160
+ tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
161
+ )
162
+
163
+ max_new_tokens = min(
164
+ max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
165
+ )
166
+
167
+ if max_new_tokens < 1:
168
+ yield json.dumps(
169
+ {
170
+ "text": ori_prompt
171
+ + "Exceeds max token length. Please start a new conversation, thanks.",
172
+ "error_code": 0,
173
+ }
174
+ ).encode() + b"\0"
175
+ return
176
+
177
+ # local inference
178
+ # BUG: If stopping_criteria is set, an error occurs:
179
+ # RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
180
+ generate_kwargs = dict(
181
+ inputs=input_ids,
182
+ do_sample=do_sample,
183
+ temperature=temperature,
184
+ top_p=top_p,
185
+ max_new_tokens=max_new_tokens,
186
+ streamer=streamer,
187
+ # stopping_criteria=[stopping_criteria],
188
+ use_cache=True,
189
+ **image_args,
190
+ )
191
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
192
+ thread.start()
193
+ logger.debug(ori_prompt)
194
+ logger.debug(generate_kwargs)
195
+ generated_text = ori_prompt
196
+ for new_text in streamer:
197
+ generated_text += new_text
198
+ if generated_text.endswith(stop_str):
199
+ generated_text = generated_text[: -len(stop_str)]
200
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode()
201
+
202
+
203
+ def http_bot(state, temperature, top_p, max_new_tokens):
204
+ if state.skip_next:
205
+ # This generate call is skipped due to invalid inputs
206
+ yield (state, state.to_gradio_chatbot())
207
+ return
208
+
209
+ if len(state.messages) == state.offset + 2:
210
+ # First round of conversation
211
+
212
+ if "tinyllava" in model_name.lower():
213
+ if "3.1b" in model_name.lower() or "phi" in model_name.lower():
214
+ template_name = "phi"
215
+ elif "2.0b" in model_name.lower() or "stablelm" in model_name.lower():
216
+ template_name = "phi"
217
+ elif "qwen" in model_name.lower():
218
+ template_name = "qwen"
219
+ else:
220
+ template_name = "v1"
221
+
222
+ elif "llava" in model_name.lower():
223
+
224
+ if "llama-2" in model_name.lower():
225
+ template_name = "llava_llama_2"
226
+ elif "v1" in model_name.lower():
227
+ if "mmtag" in model_name.lower():
228
+ template_name = "v1_mmtag"
229
+ elif (
230
+ "plain" in model_name.lower()
231
+ and "finetune" not in model_name.lower()
232
+ ):
233
+ template_name = "v1_mmtag"
234
+ else:
235
+ template_name = "llava_v1"
236
+ elif "mpt" in model_name.lower():
237
+ template_name = "mpt"
238
+ else:
239
+ if "mmtag" in model_name.lower():
240
+ template_name = "v0_mmtag"
241
+ elif (
242
+ "plain" in model_name.lower()
243
+ and "finetune" not in model_name.lower()
244
+ ):
245
+ template_name = "v0_mmtag"
246
+ else:
247
+ template_name = "llava_v0"
248
+ elif "mpt" in model_name:
249
+ template_name = "mpt_text"
250
+ elif "llama-2" in model_name:
251
+ template_name = "llama_2"
252
+ else:
253
+ template_name = "vicuna_v1"
254
+ new_state = conv_templates[template_name].copy()
255
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
256
+ new_state.append_message(new_state.roles[1], None)
257
+ state = new_state
258
+
259
+ # Construct prompt
260
+ prompt = state.get_prompt()
261
+
262
+ all_images = state.get_images(return_pil=True)
263
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
264
+
265
+ # Make requests
266
+ # pload = {"model": model_name, "prompt": prompt, "temperature": float(temperature), "top_p": float(top_p),
267
+ # "max_new_tokens": min(int(max_new_tokens), 1536), "stop": (
268
+ # state.sep
269
+ # if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
270
+ # else state.sep2
271
+ # ), "images": state.get_images()}
272
+
273
+ pload = {
274
+ "model": model_name,
275
+ "prompt": prompt,
276
+ "temperature": float(temperature),
277
+ "top_p": float(top_p),
278
+ "max_new_tokens": min(int(max_new_tokens), 1536),
279
+ "stop": (
280
+ state.sep
281
+ if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
282
+ else state.sep2
283
+ ), "images": state.get_images()}
284
+
285
+ state.messages[-1][-1] = "▌"
286
+ yield (state, state.to_gradio_chatbot())
287
+
288
+ # for stream
289
+ output = get_response(pload)
290
+ for chunk in output:
291
+ if chunk:
292
+ data = json.loads(chunk.decode())
293
+ if data["error_code"] == 0:
294
+ output = data["text"][len(prompt) :].strip()
295
+ state.messages[-1][-1] = output + "▌"
296
+ yield (state, state.to_gradio_chatbot())
297
+ else:
298
+ output = data["text"] + f" (error_code: {data['error_code']})"
299
+ state.messages[-1][-1] = output
300
+ yield (state, state.to_gradio_chatbot())
301
+ return
302
+ time.sleep(0.03)
303
+
304
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
305
+ yield (state, state.to_gradio_chatbot())
306
+
307
+
308
+ def build_demo():
309
+ textbox = gr.Textbox(
310
+ show_label=False, placeholder="Enter text and press ENTER", container=False
311
+ )
312
+ with gr.Blocks(title="TinyLLaVA", theme=gr.themes.Default(), css=block_css) as demo:
313
+ state = gr.State()
314
+ gr.Markdown(title_markdown)
315
+
316
+ with gr.Row():
317
+ with gr.Column(scale=5):
318
+ with gr.Row(elem_id="Model ID"):
319
+ gr.Dropdown(
320
+ choices=[DEFAULT_MODEL_NAME],
321
+ value=DEFAULT_MODEL_NAME,
322
+ interactive=True,
323
+ label="Model ID",
324
+ container=False,
325
+ )
326
+ imagebox = gr.Image(type="pil")
327
+ image_process_mode = gr.Radio(
328
+ ["Crop", "Resize", "Pad", "Default"],
329
+ value="Default",
330
+ label="Preprocess for non-square image",
331
+ visible=False,
332
+ )
333
+
334
+ # cur_dir = os.path.dirname(os.path.abspath(__file__))
335
+ cur_dir = Path(__file__).parent
336
+ gr.Examples(
337
+ examples=[
338
+ [
339
+ f"{cur_dir}/examples/extreme_ironing.jpg",
340
+ "What is unusual about this image?",
341
+ ],
342
+ [
343
+ f"{cur_dir}/examples/waterview.jpg",
344
+ "What are the things I should be cautious about when I visit here?",
345
+ ],
346
+ ],
347
+ inputs=[imagebox, textbox],
348
+ )
349
+
350
+ with gr.Accordion("Parameters", open=False) as _:
351
+ temperature = gr.Slider(
352
+ minimum=0.0,
353
+ maximum=1.0,
354
+ value=0.2,
355
+ step=0.1,
356
+ interactive=True,
357
+ label="Temperature",
358
+ )
359
+ top_p = gr.Slider(
360
+ minimum=0.0,
361
+ maximum=1.0,
362
+ value=0.7,
363
+ step=0.1,
364
+ interactive=True,
365
+ label="Top P",
366
+ )
367
+ max_output_tokens = gr.Slider(
368
+ minimum=0,
369
+ maximum=1024,
370
+ value=512,
371
+ step=64,
372
+ interactive=True,
373
+ label="Max output tokens",
374
+ )
375
+
376
+ with gr.Column(scale=8):
377
+ chatbot = gr.Chatbot(elem_id="chatbot", label="Chatbot", height=550)
378
+ with gr.Row():
379
+ with gr.Column(scale=8):
380
+ textbox.render()
381
+ with gr.Column(scale=1, min_width=50):
382
+ submit_btn = gr.Button(value="Send", variant="primary")
383
+ with gr.Row(elem_id="buttons") as _:
384
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
385
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=True)
386
+
387
+ gr.Markdown(tos_markdown)
388
+ gr.Markdown(learn_more_markdown)
389
+ gr.Markdown(ack_markdown)
390
+
391
+ regenerate_btn.click(
392
+ regenerate,
393
+ [state, image_process_mode],
394
+ [state, chatbot, textbox, imagebox],
395
+ queue=False,
396
+ ).then(
397
+ http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
398
+ )
399
+
400
+ clear_btn.click(
401
+ clear_history, None, [state, chatbot, textbox, imagebox], queue=False
402
+ )
403
+
404
+ textbox.submit(
405
+ add_text,
406
+ [state, textbox, imagebox, image_process_mode],
407
+ [state, chatbot, textbox, imagebox],
408
+ queue=False,
409
+ ).then(
410
+ http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
411
+ )
412
+
413
+ submit_btn.click(
414
+ add_text,
415
+ [state, textbox, imagebox, image_process_mode],
416
+ [state, chatbot, textbox, imagebox],
417
+ queue=False,
418
+ ).then(
419
+ http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
420
+ )
421
+
422
+ demo.load(load_demo, None, [state], queue=False)
423
+ return demo
424
+
425
+
426
+ def parse_args():
427
+ parser = argparse.ArgumentParser()
428
+ parser.add_argument("--host", type=str, default=None)
429
+ parser.add_argument("--port", type=int, default=None)
430
+ parser.add_argument("--share", default=None)
431
+ parser.add_argument("--model-path", type=str, default=DEFAULT_MODEL_PATH)
432
+ parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME)
433
+ parser.add_argument("--load-8bit", action="store_true")
434
+ parser.add_argument("--load-4bit", action="store_true")
435
+ args = parser.parse_args()
436
+ return args
437
+
438
+
439
+ if __name__ == "__main__":
440
+ logging.basicConfig(
441
+ level=logging.INFO,
442
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
443
+ )
444
+ logger = logging.getLogger(__name__)
445
+ logger.info(gr.__version__)
446
+ args = parse_args()
447
+ model_name = args.model_name
448
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
449
+ model_path=args.model_path,
450
+ model_base=None,
451
+ model_name=args.model_name,
452
+ load_4bit=args.load_4bit,
453
+ load_8bit=args.load_8bit
454
+ )
455
+
456
+ demo = build_demo()
457
+ demo.queue()
458
+ demo.launch(server_name=args.host, server_port=args.port, share=args.share)
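The demo above keeps all inference local: `get_response` streams tokens from the in-process model, and `http_bot` only repackages its JSON chunks for the chatbot widget. Once the `__main__` block has populated the module globals (`tokenizer`, `model`, `image_processor`, `logger`, `model_name`), `get_response` can also be driven headlessly. A minimal sketch, with an illustrative prompt template, image path, and stop string; the payload keys mirror the `pload` dict built in `http_bot`:

```python
import base64
import json

def encode_image(path):
    # get_response expects base64-encoded images (see load_image_from_base64).
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

# Illustrative payload; run after the __main__ setup above so the module
# globals used by get_response are initialized.
params = {
    "prompt": "USER: <image>\nWhat is unusual about this image? ASSISTANT:",
    "images": [encode_image("examples/extreme_ironing.jpg")],
    "temperature": 0.2,
    "top_p": 0.7,
    "max_new_tokens": 256,
    "stop": "</s>",
}

final_text = ""
for chunk in get_response(params):
    # get_response terminates its error chunk with a NUL byte, so strip it.
    data = json.loads(chunk.decode().rstrip("\x00"))
    final_text = data["text"]
print(final_text)
```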
arguments.py ADDED
@@ -0,0 +1,73 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, Optional, Sequence, List
3
+
4
+ import transformers
5
+
6
+ @dataclass
7
+ class ModelArguments:
8
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
9
+ version: Optional[str] = field(default="v0")
10
+ freeze_backbone: bool = field(default=False)
11
+ tune_mm_mlp_adapter: bool = field(default=False)
12
+ vision_tower: Optional[str] = field(default=None)
13
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
14
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
15
+ mm_projector_type: Optional[str] = field(default='linear')
16
+ mm_use_im_start_end: bool = field(default=False)
17
+ mm_use_im_patch_token: bool = field(default=True)
18
+ mm_patch_merge_type: Optional[str] = field(default='flat')
19
+ mm_vision_select_feature: Optional[str] = field(default="patch")
20
+ resampler_hidden_size: Optional[int] = field(default=768)
21
+ num_queries: Optional[int] = field(default=128)
22
+ num_resampler_layers: Optional[int] = field(default=3)
23
+ tune_vision_tower: bool = field(default=False)
24
+ tune_entire_model: bool = field(default=False)
25
+ tune_vit_from_layer: Optional[int] = field(default=100)
26
+ tune_embed_tokens: Optional[int] = field(default=False)
27
+
28
+
29
+ @dataclass
30
+ class DataArguments:
31
+ data_path: str = field(default=None,
32
+ metadata={"help": "Path to the training data."})
33
+ lazy_preprocess: bool = False
34
+ is_multimodal: bool = False
35
+ image_folder: Optional[str] = field(default=None)
36
+ image_aspect_ratio: str = 'square'
37
+
38
+
39
+ @dataclass
40
+ class TrainingArguments(transformers.TrainingArguments):
41
+ cache_dir: Optional[str] = field(default=None)
42
+ optim: str = field(default="adamw_torch")
43
+ remove_unused_columns: bool = field(default=False)
44
+ freeze_mm_mlp_adapter: bool = field(default=False)
45
+ mpt_attn_impl: Optional[str] = field(default="triton")
46
+ model_max_length: int = field(
47
+ default=512,
48
+ metadata={
49
+ "help":
50
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
51
+ },
52
+ )
53
+ double_quant: bool = field(
54
+ default=True,
55
+ metadata={"help": "Compress the quantization statistics through double quantization."}
56
+ )
57
+ quant_type: str = field(
58
+ default="nf4",
59
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
60
+ )
61
+ bits: int = field(
62
+ default=16,
63
+ metadata={"help": "How many bits to use."}
64
+ )
65
+ lora_enable: bool = False
66
+ lora_r: int = 64
67
+ lora_alpha: int = 16
68
+ lora_dropout: float = 0.05
69
+ lora_weight_path: str = ""
70
+ lora_bias: str = "none"
71
+ mm_projector_lr: Optional[float] = None
72
+ group_by_modality_length: bool = field(default=False)
73
+ vision_tower_lr: Optional[float] = None
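These three dataclasses follow the standard `transformers.HfArgumentParser` pattern, so a training entry point can turn command-line flags straight into typed configuration objects. A minimal sketch, assuming this file is importable as `arguments`; the checkpoint, data path, and output directory below are illustrative placeholders:

```python
import transformers

# Assumes this file is importable as `arguments`; the flag values are illustrative.
from arguments import ModelArguments, DataArguments, TrainingArguments

parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--model_name_or_path", "bczhou/TinyLLaVA-3.1B",
    "--data_path", "data/llava_instruct.json",
    "--image_folder", "data/images",
    "--output_dir", "./checkpoints/tinyllava-run",
    "--bits", "16",
])
print(model_args.model_name_or_path, data_args.data_path, training_args.output_dir)
```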
bard.jpg ADDED
builder.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright 2024 Baichuan Zhou , Junlong Jia, Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from tinyllava.model import *
23
+ from tinyllava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto",
27
+ device="cuda", **kwargs):
28
+ kwargs = {"device_map": device_map, **kwargs}
29
+
30
+ if device != "cuda":
31
+ kwargs['device_map'] = {"": device}
32
+
33
+ if load_8bit:
34
+ kwargs['load_in_8bit'] = True
35
+ elif load_4bit:
36
+ kwargs['load_in_4bit'] = True
37
+ kwargs['quantization_config'] = BitsAndBytesConfig(
38
+ load_in_4bit=True,
39
+ bnb_4bit_compute_dtype=torch.float16,
40
+ bnb_4bit_use_double_quant=True,
41
+ bnb_4bit_quant_type='nf4'
42
+ )
43
+ else:
44
+ kwargs['torch_dtype'] = torch.float16
45
+
46
+ # if 'tinyllava' in model_name.lower():
47
+ # Load LLaVA model
48
+ if 'lora' in model_name.lower() and model_base is None:
49
+ warnings.warn(
50
+ 'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
51
+ if 'lora' in model_name.lower() and model_base is not None:
52
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
53
+
54
+ print('Loading LLaVA from base model...')
55
+ if 'phi' in model_name.lower() or '3.1b' in model_name.lower():
56
+ tokenizer = AutoTokenizer.from_pretrained(model_base, padding_side="right")
57
+ model = TinyLlavaPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
58
+ config=lora_cfg_pretrained, **kwargs)
59
+ elif 'stablelm' in model_name.lower() or '2b' in model_name.lower():
60
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, padding_side="right")
61
+ model = TinyLlavaStablelmForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
62
+ config=lora_cfg_pretrained, **kwargs)
63
+ elif 'qwen' in model_name.lower():
64
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, padding_side="right")
65
+ model = TinyLlavaQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
66
+ config=lora_cfg_pretrained, **kwargs)
67
+ else:
68
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, padding_side="right")
69
+ model = TinyLlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
70
+ config=lora_cfg_pretrained, **kwargs)
71
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
72
+ if model.lm_head.weight.shape[0] != token_num:
73
+ model.lm_head.weight = torch.nn.Parameter(
74
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
75
+ model.model.embed_tokens.weight = torch.nn.Parameter(
76
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
77
+
78
+ print('Loading additional LLaVA weights...')
79
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
80
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
81
+ else:
82
+ # this is probably from HF Hub
83
+ from huggingface_hub import hf_hub_download
84
+ def load_from_hf(repo_id, filename, subfolder=None):
85
+ cache_file = hf_hub_download(
86
+ repo_id=repo_id,
87
+ filename=filename,
88
+ subfolder=subfolder)
89
+ return torch.load(cache_file, map_location='cpu')
90
+
91
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
92
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
93
+ non_lora_trainables.items()}
94
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
95
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
96
+ model.load_state_dict(non_lora_trainables, strict=False)
97
+
98
+ from peft import PeftModel
99
+ print('Loading LoRA weights...')
100
+ model = PeftModel.from_pretrained(model, model_path)
101
+ print('Merging LoRA weights...')
102
+ model = model.merge_and_unload()
103
+ print('Model is loaded...')
104
+ elif model_base is not None:
105
+ # this may be mm projector only
106
+ print('Loading LLaVA from base model...')
107
+
108
+ if 'phi' in model_name.lower() or '3.1b' in model_name.lower():
109
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, padding_side="right")
110
+ cfg_pretrained = TinyLlavaPhiConfig.from_pretrained(model_path)
111
+ model = TinyLlavaPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
112
+ **kwargs)
113
+ elif 'stablelm' in model_name.lower() or '2b' in model_name.lower():
114
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
115
+ cfg_pretrained = TinyLlavaStablelmConfig.from_pretrained(model_path)
116
+ model = TinyLlavaStablelmForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
117
+ config=cfg_pretrained, **kwargs)
118
+ elif 'qwen' in model_name.lower():
119
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, padding_side="right")
120
+ cfg_pretrained = TinyLlavaQwen2Config.from_pretrained(model_path)
121
+ model = TinyLlavaQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
122
+ **kwargs)
123
+ else:
124
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
125
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
126
+ model = TinyLlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
127
+ **kwargs)
128
+
129
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
130
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
131
+ model.load_state_dict(mm_projector_weights, strict=False)
132
+ else:
133
+ if 'phi' in model_name.lower() or '3.1b' in model_name.lower():
134
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="right")
135
+ model = TinyLlavaPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
136
+ elif 'stablelm' in model_name.lower() or '2.0b' in model_name.lower():
137
+ from tinyllava.model.language_model.stablelm.tokenization_arcade100k import Arcade100kTokenizer
138
+ tokenizer = Arcade100kTokenizer.from_pretrained(model_path, use_fast=False, padding_side="right")
139
+ model = TinyLlavaStablelmForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
140
+ elif 'qwen' in model_name.lower():
141
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="right")
142
+ model = TinyLlavaQwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
143
+ else:
144
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
145
+ model = TinyLlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
146
+
147
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
148
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
149
+ if mm_use_im_patch_token:
150
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
151
+ if mm_use_im_start_end:
152
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
153
+ model.resize_token_embeddings(len(tokenizer))
154
+ vision_tower = model.get_vision_tower()
155
+ if not vision_tower.is_loaded:
156
+ vision_tower.load_model()
157
+
158
+ if device != "auto":
159
+ vision_tower.to(device=device, dtype=torch.float16)
160
+
161
+ image_processor = vision_tower.image_processor
162
+
163
+ if hasattr(model.config, "max_sequence_length"):
164
+ context_len = model.config.max_sequence_length
165
+ else:
166
+ context_len = 2048
167
+
168
+ return tokenizer, model, image_processor, context_len
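`load_pretrained_model` is the single loading entry point used by the demo above: it picks the tokenizer and model classes from the model name, optionally merges LoRA or projector-only weights when `model_base` is given, and finishes by loading the vision tower. A minimal sketch of calling it directly for the default full checkpoint, using the import path app.py uses; the defaults are borrowed from app.py and the CPU fallback is illustrative:

```python
import torch
from tinyllava.mm_utils import get_model_name_from_path
from tinyllava.model.builder import load_pretrained_model

model_path = "bczhou/TinyLLaVA-3.1B"  # default checkpoint used by the demo above
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,  # full checkpoint, so no base model / LoRA merge
    model_name=get_model_name_from_path(model_path),
    load_4bit=False,  # set True to take the BitsAndBytes 4-bit branch above
    device=device,
)
print(type(model).__name__, "context length:", context_len)
```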
cache_utils.py ADDED
@@ -0,0 +1,322 @@
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+
5
+
6
+ class Cache:
7
+ """
8
+ Base, abstract class for all caches. The actual data structure is specific to each subclass.
9
+ """
10
+
11
+ def update(
12
+ self,
13
+ key_states: torch.Tensor,
14
+ value_states: torch.Tensor,
15
+ layer_idx: int,
16
+ cache_kwargs: Optional[Dict[str, Any]] = None,
17
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
18
+ """
19
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
20
+
21
+ Parameters:
22
+ key_states (`torch.Tensor`):
23
+ The new key states to cache.
24
+ value_states (`torch.Tensor`):
25
+ The new value states to cache.
26
+ layer_idx (`int`):
27
+ The index of the layer to cache the states for.
28
+ cache_kwargs (`Dict[str, Any]`, `optional`):
29
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
30
+ cache to be created.
31
+
32
+ Return:
33
+ A tuple containing the updated key and value states.
34
+ """
35
+ raise NotImplementedError("Make sure to implement `update` in a subclass.")
36
+
37
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
38
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
39
+ raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
40
+
41
+ def get_max_length(self) -> Optional[int]:
42
+ """Returns the maximum sequence length of the cached states, if there is any."""
43
+ raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
44
+
45
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
46
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
47
+ # Cache without size limit -> all cache is usable
48
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
49
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
50
+ max_length = self.get_max_length()
51
+ previous_seq_length = self.get_seq_length(layer_idx)
52
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
53
+ return max_length - new_seq_length
54
+ return previous_seq_length
55
+
56
+
57
+ class DynamicCache(Cache):
58
+ """
59
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
60
+
61
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
62
+ `[batch_size, num_heads, seq_len, head_dim]`.
63
+ """
64
+
65
+ def __init__(self) -> None:
66
+ self.key_cache: List[torch.Tensor] = []
67
+ self.value_cache: List[torch.Tensor] = []
68
+ self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
69
+
70
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
71
+ """
72
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
73
+ sequence length.
74
+ """
75
+ if layer_idx < len(self):
76
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx])
77
+ else:
78
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
79
+
80
+ def __iter__(self):
81
+ """
82
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
83
+ keys and values
84
+ """
85
+ for layer_idx in range(len(self)):
86
+ yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
87
+
88
+ def __len__(self):
89
+ """
90
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
91
+ to the number of layers in the model.
92
+ """
93
+ return len(self.key_cache)
94
+
95
+ def update(
96
+ self,
97
+ key_states: torch.Tensor,
98
+ value_states: torch.Tensor,
99
+ layer_idx: int,
100
+ cache_kwargs: Optional[Dict[str, Any]] = None,
101
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
102
+ """
103
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
104
+
105
+ Parameters:
106
+ key_states (`torch.Tensor`):
107
+ The new key states to cache.
108
+ value_states (`torch.Tensor`):
109
+ The new value states to cache.
110
+ layer_idx (`int`):
111
+ The index of the layer to cache the states for.
112
+ cache_kwargs (`Dict[str, Any]`, `optional`):
113
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
114
+
115
+ Return:
116
+ A tuple containing the updated key and value states.
117
+ """
118
+ # Update the number of seen tokens
119
+ if layer_idx == 0:
120
+ self.seen_tokens += key_states.shape[-2]
121
+
122
+ # Update the cache
123
+ if len(self.key_cache) <= layer_idx:
124
+ self.key_cache.append(key_states)
125
+ self.value_cache.append(value_states)
126
+ else:
127
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
128
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
129
+
130
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
131
+
132
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
133
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
134
+ if len(self.key_cache) <= layer_idx:
135
+ return 0
136
+ return self.key_cache[layer_idx].shape[-2]
137
+
138
+ def get_max_length(self) -> Optional[int]:
139
+ """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
140
+ return None
141
+
142
+ def reorder_cache(self, beam_idx: torch.LongTensor):
143
+ """Reorders the cache for beam search, given the selected beam indices."""
144
+ for layer_idx in range(len(self.key_cache)):
145
+ device = self.key_cache[layer_idx].device
146
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
147
+ device = self.value_cache[layer_idx].device
148
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
149
+
+     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+         """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
+         legacy_cache = ()
+         for layer_idx in range(len(self)):
+             legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+         return legacy_cache
+
+     @classmethod
+     def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+         cache = cls()
+         if past_key_values is not None:
+             for layer_idx in range(len(past_key_values)):
+                 key_states, value_states = past_key_values[layer_idx]
+                 cache.update(key_states, value_states, layer_idx)
+         return cache
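+
+     # The legacy format is a tuple with one (key, value) pair per layer; the two helpers above
+     # round-trip it, e.g. (illustrative):
+     #     legacy = cache.to_legacy_cache()
+     #     cache_again = DynamicCache.from_legacy_cache(legacy)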
+
+
+ class SinkCache(Cache):
+     """
+     A cache that, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), allows the model to
+     generate beyond the length of its context window without losing fluency in the conversation. As it discards past
+     tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
+
+     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+     `[batch_size, num_heads, seq_len, head_dim]`.
+
+     Parameters:
+         window_length (`int`):
+             The length of the context window.
+         num_sink_tokens (`int`):
+             The number of sink tokens. See the original paper for more information.
+     """
+
+     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+         self.key_cache: List[torch.Tensor] = []
+         self.value_cache: List[torch.Tensor] = []
+         self.window_length = window_length
+         self.num_sink_tokens = num_sink_tokens
+         self.cos_sin_cache = {}
+         self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
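+
+     # Illustrative usage sketch (assumes a standard `generate` call that accepts `past_key_values`):
+     #     past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
+     #     out = model.generate(**inputs, past_key_values=past_key_values)
+     # The first `num_sink_tokens` positions are always kept; the remainder is a rolling window.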
+
+     @staticmethod
+     def _rotate_half(x):
+         x1 = x[..., : x.shape[-1] // 2]
+         x2 = x[..., x.shape[-1] // 2 :]
+         return torch.cat((-x2, x1), dim=-1)
+
+     def _apply_key_rotary_pos_emb(
+         self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+     ) -> torch.Tensor:
+         rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
+         return rotated_key_states
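+
+     # These two helpers mirror the standard RoPE `rotate_half` / `apply_rotary_pos_emb` functions;
+     # only the cached keys are re-rotated, since value states carry no positional information.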
+
+     def _get_rerotation_cos_sin(
+         self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if key_states.shape[-2] not in self.cos_sin_cache:
+             # Upcast to float32 temporarily for better accuracy
+             cos = cos.to(torch.float32)
+             sin = sin.to(torch.float32)
+
+             # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
+             original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
+             shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
+             original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
+             shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
+             rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
+             rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
+
+             self.cos_sin_cache[key_states.shape[-2]] = (
+                 rerotation_cos.to(key_states.dtype).unsqueeze(0),
+                 rerotation_sin.to(key_states.dtype).unsqueeze(0),
+             )
+         return self.cos_sin_cache[key_states.shape[-2]]
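+
+     # The products above are the angle-difference identities:
+     #     rerotation_cos = cos(original - shifted),  rerotation_sin = sin(shifted - original)
+     # i.e. a rotation by (shifted - original) positions, which moves a cached key back to the
+     # position it will occupy once the window has shifted.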
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+         if len(self.key_cache) <= layer_idx:
+             return 0
+         return self.key_cache[layer_idx].shape[-2]
+
+     def get_max_length(self) -> Optional[int]:
+         """Returns the maximum sequence length of the cached states."""
+         return self.window_length
+
+     def update(
+         self,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+         layer_idx: int,
+         cache_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+         Parameters:
+             key_states (`torch.Tensor`):
+                 The new key states to cache.
+             value_states (`torch.Tensor`):
+                 The new value states to cache.
+             layer_idx (`int`):
+                 The index of the layer to cache the states for.
+             cache_kwargs (`Dict[str, Any]`, `optional`):
+                 Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
+                 `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
+                 rotation as the tokens are shifted.
+
+         Return:
+             A tuple containing the updated key and value states.
+         """
+         # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
+         # with partially rotated position embeddings, like Phi or Persimmon.
+         sin = cache_kwargs.get("sin")
+         cos = cache_kwargs.get("cos")
+         partial_rotation_size = cache_kwargs.get("partial_rotation_size")
+         using_rope = cos is not None and sin is not None
+
+         # Update the number of seen tokens
+         if layer_idx == 0:
+             self.seen_tokens += key_states.shape[-2]
+
+         # [bsz, num_heads, seq_len, head_dim]
+         if len(self.key_cache) <= layer_idx:
+             # Empty cache
+             self.key_cache.append(key_states)
+             self.value_cache.append(value_states)
+
+         elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
+             # Growing cache
+             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+         else:
+             # Shifting cache
+             keys_to_keep = self.key_cache[layer_idx][
+                 :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
+             ]
+
+             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
+             if using_rope:
+                 rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                     key_states, cos[: self.window_length], sin[: self.window_length]
+                 )
+                 if partial_rotation_size is not None:
+                     keys_to_keep, keys_pass = (
+                         keys_to_keep[..., :partial_rotation_size],
+                         keys_to_keep[..., partial_rotation_size:],
+                     )
+                 keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
+                 if partial_rotation_size is not None:
+                     keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
+
+             # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
+             sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
+             self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+
+             sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
+             values_to_keep = self.value_cache[layer_idx][
+                 :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
+             ]
+             self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+
+         return self.key_cache[layer_idx], self.value_cache[layer_idx]
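+
+     # After a shift, each layer holds [sink tokens | most recent window tail | new tokens],
+     # so the per-layer sequence length stays capped at `window_length`.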
+
+     def reorder_cache(self, beam_idx: torch.LongTensor):
+         """Reorders the cache for beam search, given the selected beam indices."""
+         for layer_idx in range(len(self.key_cache)):
+             device = self.key_cache[layer_idx].device
+             self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+             device = self.value_cache[layer_idx].device
+             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
caps_boxes_coco2014_val_80.jsonl ADDED
@@ -0,0 +1,80 @@
+ {"id": "000000296284", "image": "000000296284.jpg", "captions": ["A donut shop is full of different flavors of donuts.", "Fruit flavored donuts lined up in a glass fronted cabinet", "A rack with some doughnuts in a glass case.", "A display case in a bakery filled with donuts.", "An assortment of doughnuts are arranged in a display case."], "instances": [{"category": "donut", "bbox": [0.37, 0.584, 0.504, 0.709]}, {"category": "donut", "bbox": [0.369, 0.22, 0.492, 0.317]}, {"category": "donut", "bbox": [0.471, 0.587, 0.639, 0.706]}, {"category": "donut", "bbox": [0.544, 0.213, 0.679, 0.316]}, {"category": "donut", "bbox": [0.035, 0.22, 0.196, 0.328]}, {"category": "donut", "bbox": [0.054, 0.608, 0.221, 0.711]}, {"category": "donut", "bbox": [0.283, 0.586, 0.429, 0.708]}, {"category": "donut", "bbox": [0.466, 0.226, 0.585, 0.32]}, {"category": "donut", "bbox": [0.28, 0.232, 0.393, 0.322]}, {"category": "donut", "bbox": [0.0, 0.609, 0.097, 0.722]}]}
+ {"id": "000000151358", "image": "000000151358.jpg", "captions": ["A newspaper that has sunglasses on top of it sitting in front of books.", "an apple sunglasses books and a teddy bear", "A folded newspaper and sunglasses are on a table with an apple, books, and teddy bear behind.", "An apple sitting on a table next to sunglasses and a news paper.", "There are sunglasses laying on the folded newspaper."], "instances": [{"category": "tie", "bbox": [0.258, 0.074, 0.527, 0.589]}, {"category": "apple", "bbox": [0.621, 0.482, 0.853, 0.645]}, {"category": "book", "bbox": [0.154, 0.107, 0.275, 0.59]}, {"category": "book", "bbox": [0.535, 0.09, 0.735, 0.583]}, {"category": "book", "bbox": [0.051, 0.112, 0.159, 0.6]}, {"category": "teddy bear", "bbox": [0.753, 0.084, 1.0, 0.517]}, {"category": "book", "bbox": [0.681, 0.097, 0.796, 0.483]}, {"category": "book", "bbox": [0.443, 0.099, 0.574, 0.588]}, {"category": "book", "bbox": [0.267, 0.337, 0.386, 0.579]}]}
+ {"id": "000000052312", "image": "000000052312.jpg", "captions": ["The old man literally has a toothbrush mustache.", "An old man with a tooth brush head under his nose, mimicking Hitler", "A man wearing a toothbrush for a moustache.", "A man with the head of a toothbrush under his nose like a mustache", "An elderly man wearing the head of a toothbrush as a moustache."], "instances": [{"category": "toothbrush", "bbox": [0.345, 0.59, 0.594, 0.679]}, {"category": "person", "bbox": [0.0, 0.03, 1.0, 0.99]}]}
+ {"id": "000000473210", "image": "000000473210.jpg", "captions": ["two people taking apart their wii controllers to replace batteries", "People taking apart video game remote controls on a table", "People handling a couple of remotes taking them apart.", "two sets of hands a wooden table and two controllers", "Two people who are taking apart a video game controller."], "instances": [{"category": "person", "bbox": [0.002, 0.334, 0.453, 0.986]}, {"category": "remote", "bbox": [0.407, 0.207, 0.727, 0.604]}, {"category": "remote", "bbox": [0.088, 0.344, 0.313, 0.547]}, {"category": "laptop", "bbox": [0.001, 0.049, 0.1, 0.197]}, {"category": "person", "bbox": [0.484, 0.254, 0.998, 0.985]}, {"category": "dining table", "bbox": [0.0, 0.003, 1.0, 0.956]}]}
+ {"id": "000000097131", "image": "000000097131.jpg", "captions": ["A car parked by a parking meter in front of a building.", "A car is sitting parked at a curb in front of a parking meter.", "A black car on the street next to a parking meter.", "A gray car parked in front of two parking meters.", "A black car parked on the side of the road."], "instances": [{"category": "car", "bbox": [0.227, 0.362, 0.946, 0.761]}, {"category": "car", "bbox": [0.793, 0.322, 0.88, 0.4]}, {"category": "car", "bbox": [0.0, 0.447, 0.028, 0.726]}, {"category": "parking meter", "bbox": [0.156, 0.35, 0.186, 0.453]}, {"category": "truck", "bbox": [0.907, 0.331, 1.0, 0.408]}, {"category": "parking meter", "bbox": [0.188, 0.349, 0.218, 0.448]}]}
+ {"id": "000000543364", "image": "000000543364.jpg", "captions": ["There is a table in the middle of the room.", "A room with a couch, table, lamp and a chaise.", "A living room with couch, chaise, track lighting, and a large window.", "A room with large windows, a couch and a table.", "A living room with lots of furniture and a large window."], "instances": [{"category": "dining table", "bbox": [0.388, 0.644, 0.636, 0.879]}, {"category": "couch", "bbox": [0.194, 0.531, 0.552, 0.777]}, {"category": "couch", "bbox": [0.568, 0.488, 0.907, 0.783]}, {"category": "remote", "bbox": [0.524, 0.651, 0.556, 0.675]}, {"category": "chair", "bbox": [0.661, 0.478, 0.802, 0.604]}]}
+ {"id": "000000217181", "image": "000000217181.jpg", "captions": ["They are standing next to some stylish motorcycles.", "Three men are standing around looking at sports motorcycles.", "A small group of men are standing around a motorcycle.", "Two men surrounding a blue motorcycle and others", "A few blue motorcycles are parked in a lot."], "instances": [{"category": "car", "bbox": [0.011, 0.177, 0.2, 0.336]}, {"category": "motorcycle", "bbox": [0.032, 0.139, 0.907, 0.982]}, {"category": "motorcycle", "bbox": [0.0, 0.239, 0.148, 0.613]}, {"category": "motorcycle", "bbox": [0.0, 0.301, 0.106, 0.45]}, {"category": "person", "bbox": [0.775, 0.043, 0.93, 0.463]}, {"category": "person", "bbox": [0.717, 0.116, 0.81, 0.509]}, {"category": "person", "bbox": [0.296, 0.008, 0.472, 0.325]}, {"category": "person", "bbox": [0.115, 0.19, 0.164, 0.269]}, {"category": "truck", "bbox": [0.63, 0.227, 0.731, 0.335]}]}
+ {"id": "000000140289", "image": "000000140289.jpg", "captions": ["Two born bears walking though a forest surrounded by trees.", "Two full grown brown bears in a habitat.", "Two bears are roaming around in the woods.", "Two bears around logs in front of a large rock.", "Two big bears wandering through the woods together"], "instances": [{"category": "bear", "bbox": [0.131, 0.269, 0.375, 0.65]}, {"category": "bear", "bbox": [0.568, 0.193, 0.809, 0.827]}]}
+ {"id": "000000460149", "image": "000000460149.jpg", "captions": ["A clock hosted on a pole on a pavement next to a building", "Street clock on quiet street with trees and bicycles.", "A tall clock stands on an empty sidewalk.", "A pole that has a clock on the top of it.", "a clock on a short tower and potted plants along the sidewalk"], "instances": [{"category": "potted plant", "bbox": [0.14, 0.71, 0.338, 0.856]}, {"category": "bicycle", "bbox": [0.65, 0.671, 0.766, 0.733]}, {"category": "car", "bbox": [0.38, 0.608, 0.488, 0.656]}, {"category": "clock", "bbox": [0.468, 0.048, 0.699, 0.216]}, {"category": "bicycle", "bbox": [0.669, 0.662, 0.719, 0.67]}, {"category": "car", "bbox": [0.786, 0.625, 0.86, 0.668]}, {"category": "potted plant", "bbox": [0.756, 0.637, 0.819, 0.682]}, {"category": "person", "bbox": [0.942, 0.615, 0.954, 0.641]}, {"category": "bicycle", "bbox": [0.648, 0.68, 0.714, 0.747]}, {"category": "car", "bbox": [0.837, 0.619, 0.88, 0.659]}, {"category": "potted plant", "bbox": [0.017, 0.197, 0.443, 0.686]}]}
+ {"id": "000000225738", "image": "000000225738.jpg", "captions": ["A group of giraffes standing up in their natural habitat.", "A group of giraffe standing in a grass field.", "A group of four giraffes near the same tree.", "there are four giraffes standing among some dry brush", "A herd of giraffe standing on top of a grass field."], "instances": [{"category": "giraffe", "bbox": [0.648, 0.231, 0.855, 0.915]}, {"category": "giraffe", "bbox": [0.33, 0.136, 0.521, 0.93]}, {"category": "giraffe", "bbox": [0.406, 0.261, 0.515, 1.0]}, {"category": "giraffe", "bbox": [0.347, 0.194, 0.583, 0.922]}]}
+ {"id": "000000109532", "image": "000000109532.jpg", "captions": ["An adorable husky dog sleeping in a dog bed next to a fan.", "A dark room with a dog sleeping on a dog bed.", "A dog is sleeping in a dark room.", "a large dog laying in a dog bed in a living room", "A dog sleeping on a dog bed in a room."], "instances": [{"category": "dog", "bbox": [0.426, 0.661, 0.582, 0.925]}, {"category": "potted plant", "bbox": [0.603, 0.261, 0.781, 0.613]}, {"category": "chair", "bbox": [0.67, 0.515, 0.899, 0.801]}, {"category": "potted plant", "bbox": [0.671, 0.439, 0.763, 0.612]}, {"category": "chair", "bbox": [0.852, 0.653, 0.948, 0.818]}]}
+ {"id": "000000118606", "image": "000000118606.jpg", "captions": ["A man riding skis on top of a rail.", "a person riding a pair of skis on a rail", "Someone on a pair of skis on a ramp at the ski slope", "Person with skis in the air above the snow.", "A man performing a trick on a rail while skiing."], "instances": [{"category": "person", "bbox": [0.444, 0.361, 0.537, 0.633]}, {"category": "skis", "bbox": [0.413, 0.554, 0.539, 0.664]}, {"category": "person", "bbox": [0.342, 0.585, 0.352, 0.62]}, {"category": "person", "bbox": [0.439, 0.565, 0.446, 0.58]}]}
+ {"id": "000000385873", "image": "000000385873.jpg", "captions": ["Three pizzas sitting next to each other in boxes.", "Two smaller pizzas sit beside a large pizza topped with tortilla chips.", "Three pizzas inside their delivery boxes, one with two side orders of sauce.", "One pizza is larger than two other pizzas.", "Three pizza boxes with pizza in them are open."], "instances": [{"category": "bowl", "bbox": [0.634, 0.624, 0.736, 0.752]}, {"category": "pizza", "bbox": [0.3, 0.382, 0.615, 0.733]}, {"category": "pizza", "bbox": [0.0, 0.4, 0.287, 0.745]}, {"category": "pizza", "bbox": [0.624, 0.279, 0.999, 0.753]}, {"category": "bowl", "bbox": [0.94, 0.247, 1.0, 0.352]}]}
+ {"id": "000000092109", "image": "000000092109.jpg", "captions": ["A giraffe's head is pictured in this clear, colorful photo.", "A giraffe is standing tall in the middle of several bright green trees", "The face of a giraffe looking to the side.", "the close up head shot of a giraffe", "this is a giraffe chewing on some leaves"], "instances": [{"category": "giraffe", "bbox": [0.236, 0.122, 1.0, 0.987]}]}
+ {"id": "000000163076", "image": "000000163076.jpg", "captions": ["There's an outdoor dining area featuring a fountain.", "A table sitting next to a water fountain covered by an umbrella.", "An empty restaurant patio with tables and umbrellas.", "An outdoor restaurant with a fountain at night", "A fountain bubbles in the plaza of an outdoor cafe."], "instances": [{"category": "umbrella", "bbox": [0.064, 0.069, 0.95, 0.844]}, {"category": "chair", "bbox": [0.198, 0.574, 0.355, 0.704]}, {"category": "chair", "bbox": [0.42, 0.571, 0.55, 0.738]}, {"category": "dining table", "bbox": [0.066, 0.741, 0.766, 0.925]}, {"category": "dining table", "bbox": [0.059, 0.584, 0.27, 0.659]}, {"category": "chair", "bbox": [0.432, 0.567, 0.52, 0.624]}, {"category": "chair", "bbox": [0.433, 0.555, 0.504, 0.6]}, {"category": "chair", "bbox": [0.109, 0.673, 0.374, 0.796]}]}
+ {"id": "000000560371", "image": "000000560371.jpg", "captions": ["Street signs from the corner of 8th ave. and 22 3/4 st.", "A two way street sign with one sign that changes from one name to another.", "A street sign is pointing towards 8th avenue and the other is pointing towards 22 3/4 street in the middle of the forest.", "A street sign standing in front of some trees.", "Peculiar street sign showing intersection of 23 3/4 St and 8th Ave/CTH D."], "instances": []}
+ {"id": "000000367571", "image": "000000367571.jpg", "captions": ["A couple of different doughnuts in a box", "There are four donuts in a box, and some are cake donuts and a doughnut with nuts and coconut on top.", "A box of glazed doughnuts on a table.", "Three donuts with toppings on them sitting inside a box.", "A box that is filled with different kinds of doughnuts."], "instances": [{"category": "donut", "bbox": [0.412, 0.335, 0.711, 0.681]}, {"category": "donut", "bbox": [0.093, 0.493, 0.486, 0.922]}, {"category": "donut", "bbox": [0.713, 0.423, 0.957, 0.874]}, {"category": "donut", "bbox": [0.13, 0.331, 0.397, 0.55]}]}
+ {"id": "000000580197", "image": "000000580197.jpg", "captions": ["Two men in bow ties standing next to steel rafter.", "Several men in suits talking together in a room.", "An older man in a tuxedo standing next to a younger man in a tuxedo wearing glasses.", "Two men wearing tuxedos glance at each other.", "Older man in tuxedo sitting next to another younger man in tuxedo."], "instances": [{"category": "tie", "bbox": [0.914, 0.46, 0.984, 0.512]}, {"category": "person", "bbox": [0.297, 0.638, 0.71, 0.989]}, {"category": "person", "bbox": [0.77, 0.177, 1.0, 0.971]}, {"category": "tie", "bbox": [0.281, 0.481, 0.368, 0.519]}, {"category": "person", "bbox": [0.103, 0.204, 0.497, 1.0]}]}
+ {"id": "000000506095", "image": "000000506095.jpg", "captions": ["A cat is staring at a laptop computer.", "a cat on a desk with a laptop and a mouse", "A cat that is sitting at a desk next to a laptop.", "A kitten sitting on a laptop computer sitting on top of a wooden desk.", "A kitten sits facing an open black laptop."], "instances": [{"category": "cat", "bbox": [0.658, 0.207, 1.0, 0.754]}, {"category": "laptop", "bbox": [0.108, 0.135, 0.766, 0.69]}, {"category": "book", "bbox": [0.836, 0.239, 0.954, 0.273]}, {"category": "book", "bbox": [0.0, 0.556, 0.128, 0.685]}, {"category": "book", "bbox": [0.039, 0.574, 0.257, 0.691]}, {"category": "book", "bbox": [0.825, 0.214, 0.962, 0.254]}, {"category": "book", "bbox": [0.892, 0.275, 0.958, 0.308]}, {"category": "book", "bbox": [0.922, 0.318, 0.986, 0.353]}, {"category": "book", "bbox": [0.87, 0.267, 0.951, 0.291]}, {"category": "book", "bbox": [0.949, 0.102, 0.976, 0.114]}, {"category": "book", "bbox": [0.936, 0.161, 0.958, 0.168]}]}
+ {"id": "000000024996", "image": "000000024996.jpg", "captions": ["A bathroom with a glass door and a sink.", "A blue lined bathroom with an open glass door.", "A nice bathroom with a sink, toilet, and tiled shower.", "A bathroom that is clean and shiny in the day.", "a bathroom with a sink and a mirror and a window"], "instances": [{"category": "toilet", "bbox": [0.842, 0.934, 0.95, 1.0]}, {"category": "sink", "bbox": [0.506, 0.724, 0.683, 0.834]}]}
+ {"id": "000000457882", "image": "000000457882.jpg", "captions": ["a girl in a bikini and a brown and white dog and a few other people", "A woman with a swimsuit on sitting with a dog.", "A woman is sitting with a dog on her lap.", "A dog sitting next to a woman in her swimsuit.", "WOMAN SITTING WITH HER DOG, AND OTHER WOMEN ARE AROUND"], "instances": [{"category": "dog", "bbox": [0.202, 0.409, 0.54, 0.81]}, {"category": "dog", "bbox": [0.61, 0.428, 0.729, 0.723]}, {"category": "boat", "bbox": [0.003, 0.705, 0.939, 0.974]}, {"category": "person", "bbox": [0.236, 0.001, 0.558, 0.784]}, {"category": "person", "bbox": [0.681, 0.001, 0.957, 0.798]}, {"category": "person", "bbox": [0.849, 0.478, 1.0, 0.946]}, {"category": "person", "bbox": [0.345, 0.187, 0.634, 0.828]}, {"category": "person", "bbox": [0.033, 0.345, 0.109, 0.434]}]}
+ {"id": "000000081552", "image": "000000081552.jpg", "captions": ["A cat sitting and curled up on a red couch", "A cat laying on a red couch sleeping.", "a tan and black cat curled up asleep on a red velvet seat", "A cat is curled up on a red sofa.", "Cat curled up, sleeping on a red plush couch."], "instances": [{"category": "cat", "bbox": [0.412, 0.237, 0.634, 0.482]}, {"category": "couch", "bbox": [0.003, 0.005, 1.0, 0.99]}]}
+ {"id": "000000273450", "image": "000000273450.jpg", "captions": ["A person flipping of a parking meter on the side of a road.", "A man holds up his middle finger to a parking meter.", "Person giving the middle finger to a parking meter.", "a black silver white blue red an orange parking meter and a hand flipping it off", "A person is flipping off a parking meter."], "instances": [{"category": "person", "bbox": [0.0, 0.475, 0.565, 0.987]}, {"category": "car", "bbox": [0.0, 0.0, 0.531, 0.734]}, {"category": "parking meter", "bbox": [0.0, 0.0, 1.0, 0.987]}]}
+ {"id": "000000203879", "image": "000000203879.jpg", "captions": ["There is a small cellphone displayed between a set of ear buds and two paper weights.", "a cell phone lays next to some diamonds", "a close up of a cell phone on a table near earbuds", "A cell phone sits on a table next to some jewels.", "A cell phone, ear buds, and two jewels laying near each other."], "instances": [{"category": "cell phone", "bbox": [0.322, 0.233, 0.62, 0.79]}]}
+ {"id": "000000346875", "image": "000000346875.jpg", "captions": ["two zebras in a field near one another", "A couple of zebra walking across a green field.", "Two zebra are walking near a gravel road.", "two zebras in a green field of grass and some trees", "A zebra follows another zebra through a park."], "instances": [{"category": "zebra", "bbox": [0.591, 0.263, 0.82, 0.466]}, {"category": "zebra", "bbox": [0.293, 0.243, 0.561, 0.45]}]}
+ {"id": "000000525439", "image": "000000525439.jpg", "captions": ["a man stands in front of a flipped skate boarder", "A man standing next to a skateboard that is laying on the ground wheels pointed up.", "Skateboard laying upside down on cement with someone standing next to it.", "A boy in camo shorts stands before an overturned skateboard.", "a person with an upside down skate board"], "instances": [{"category": "person", "bbox": [0.307, 0.001, 0.63, 0.739]}, {"category": "skateboard", "bbox": [0.0, 0.592, 0.626, 0.969]}]}
+ {"id": "000000304749", "image": "000000304749.jpg", "captions": ["The woman is taking a picture in the bathroom mirror.", "A picture of a woman in a mirror.", "A woman's midsection reflected in a round mirror.", "A circular mirror reflecting a woman's stomach in turquoise shirt.", "A selfie taken of a person from the neck down."], "instances": [{"category": "person", "bbox": [0.092, 0.001, 0.646, 0.496]}]}
+ {"id": "000000323760", "image": "000000323760.jpg", "captions": ["A toilet is shown in a bare room.", "A ugly bathroom with a section of the wall missing.", "A toilet in a stripped bathroom with studs, bricks and plaster showing", "A bathroom with no walls and a toilet bowl", "A white toilet next to some torn out walls."], "instances": [{"category": "toilet", "bbox": [0.167, 0.585, 0.714, 1.0]}]}
+ {"id": "000000066144", "image": "000000066144.jpg", "captions": ["A woman standing in front of window next to a bug and a stop sign.", "A car parked on the street next to a tree and stop sign.", "A lone Volkswagen is parked by a stop sign.", "A window view of a small car near a street stop sign.", "An old VW Bug standing at a stop sign."], "instances": [{"category": "stop sign", "bbox": [0.501, 0.328, 0.569, 0.428]}, {"category": "car", "bbox": [0.242, 0.488, 0.56, 0.726]}, {"category": "car", "bbox": [0.279, 0.325, 0.33, 0.363]}, {"category": "car", "bbox": [0.153, 0.333, 0.29, 0.405]}, {"category": "car", "bbox": [0.11, 0.339, 0.177, 0.373]}, {"category": "car", "bbox": [0.0, 0.654, 0.082, 0.826]}, {"category": "car", "bbox": [0.0, 0.322, 0.064, 0.364]}, {"category": "car", "bbox": [0.451, 0.333, 0.51, 0.392]}]}
+ {"id": "000000455772", "image": "000000455772.jpg", "captions": ["A person in a field jumping to catch a Frisbee.", "A guy jumping to catch a frisbee in mid-air.", "A person that is trying to get a frisbee.", "Nice reach, but the Frisbee flies on, victorious.", "A man playing frisbee in a grassy yard."], "instances": [{"category": "car", "bbox": [0.148, 0.339, 0.201, 0.476]}, {"category": "car", "bbox": [0.376, 0.396, 0.424, 0.476]}, {"category": "person", "bbox": [0.547, 0.122, 0.698, 0.904]}, {"category": "frisbee", "bbox": [0.479, 0.154, 0.555, 0.231]}, {"category": "car", "bbox": [0.001, 0.299, 0.085, 0.394]}]}
+ {"id": "000000511117", "image": "000000511117.jpg", "captions": ["A couple of kids standing on top of a grass covered field.", "A little boy wearing a baseball uniform stands by a little girl.", "A young boy in a baseball uniform and a young girl are standing in front of a chain link fence.", "A little boy and girl standing on a baseball field. The boy has a uniform on.", "A young baseball player is standing next to a young girl."], "instances": [{"category": "person", "bbox": [0.514, 0.178, 0.776, 0.774]}, {"category": "baseball glove", "bbox": [0.468, 0.462, 0.593, 0.609]}, {"category": "person", "bbox": [0.174, 0.051, 0.598, 0.839]}, {"category": "bench", "bbox": [0.558, 0.125, 1.0, 0.315]}]}
+ {"id": "000000207151", "image": "000000207151.jpg", "captions": ["A vegetarian pizza is half eaten on a pizza holder.", "A couple of pieces of pizza with vegetable slices on them.", "A wooden pan serving tray with a pizza on it.", "A pizza on a cutting board is half gone.", "A Pizza is nearly finished with only three pieces left."], "instances": [{"category": "bottle", "bbox": [0.001, 0.001, 0.121, 0.231]}, {"category": "cup", "bbox": [0.0, 0.002, 0.121, 0.238]}, {"category": "pizza", "bbox": [0.17, 0.472, 0.526, 0.82]}, {"category": "pizza", "bbox": [0.398, 0.106, 0.962, 0.679]}, {"category": "dining table", "bbox": [0.0, 0.001, 1.0, 0.988]}]}
+ {"id": "000000431165", "image": "000000431165.jpg", "captions": ["A baby elephant standing in front of a brick building.", "An elephant is standing near a dirt mount in an exhibit.", "Grey elephant standing next to a large sand dune in a pen.", "An elephant standing alone inside of an enclosure.", "The baby elephant is alone in the pen."], "instances": [{"category": "elephant", "bbox": [0.303, 0.399, 0.638, 0.78]}]}
+ {"id": "000000378545", "image": "000000378545.jpg", "captions": ["A pole that has a clock on top of it.", "A clock mounted on an outdoor post with Roman numerals.", "a clock on a pole saying it is 12:45", "An ornamental standing clock is at the foreground of a row of houses.", "A black and gold clock on a pole in front of a building."], "instances": [{"category": "clock", "bbox": [0.216, 0.249, 0.749, 0.658]}]}
+ {"id": "000000555904", "image": "000000555904.jpg", "captions": ["A man sitting at a bar filled with liquor.", "People sitting a a take near several bottles of wine on shelves.", "Several people are sitting at a table drinking.", "Several people in a bar sitting at a long table.", "People eating in a restaurant near wine bottles."], "instances": [{"category": "dining table", "bbox": [0.123, 0.663, 0.317, 0.811]}, {"category": "person", "bbox": [0.715, 0.239, 1.0, 0.998]}, {"category": "person", "bbox": [0.142, 0.528, 0.281, 0.742]}, {"category": "person", "bbox": [0.529, 0.53, 0.606, 0.69]}, {"category": "person", "bbox": [0.705, 0.518, 0.796, 0.673]}, {"category": "wine glass", "bbox": [0.247, 0.669, 0.27, 0.718]}, {"category": "person", "bbox": [0.281, 0.524, 0.534, 1.0]}, {"category": "bottle", "bbox": [0.168, 0.346, 0.189, 0.425]}, {"category": "bottle", "bbox": [0.379, 0.264, 0.431, 0.433]}, {"category": "bottle", "bbox": [0.252, 0.313, 0.277, 0.429]}, {"category": "bottle", "bbox": [0.294, 0.295, 0.326, 0.43]}, {"category": "bottle", "bbox": [0.589, 0.35, 0.613, 0.444]}, {"category": "bottle", "bbox": [0.433, 0.281, 0.473, 0.437]}, {"category": "bottle", "bbox": [0.478, 0.289, 0.513, 0.44]}, {"category": "wine glass", "bbox": [0.688, 0.615, 0.709, 0.69]}, {"category": "cup", "bbox": [0.589, 0.647, 0.612, 0.693]}, {"category": "person", "bbox": [0.732, 0.356, 0.953, 0.806]}, {"category": "bottle", "bbox": [0.555, 0.337, 0.585, 0.438]}, {"category": "bottle", "bbox": [0.337, 0.29, 0.378, 0.432]}, {"category": "bottle", "bbox": [0.21, 0.333, 0.232, 0.426]}, {"category": "bottle", "bbox": [0.134, 0.36, 0.148, 0.422]}, {"category": "bottle", "bbox": [0.516, 0.312, 0.557, 0.439]}, {"category": "cup", "bbox": [0.231, 0.718, 0.26, 0.763]}, {"category": "chair", "bbox": [0.517, 0.828, 0.65, 0.999]}, {"category": "chair", "bbox": [0.643, 0.804, 0.738, 0.841]}, {"category": "chair", "bbox": [0.347, 0.908, 0.519, 1.0]}, {"category": "chair", "bbox": [0.64, 0.806, 0.74, 0.998]}, {"category": "cup", "bbox": [0.205, 0.692, 0.232, 0.767]}, {"category": "dining table", "bbox": [0.536, 0.676, 0.743, 0.838]}, {"category": "person", "bbox": [0.002, 0.501, 0.263, 0.987]}, {"category": "bottle", "bbox": [0.531, 0.461, 0.542, 0.526]}, {"category": "bottle", "bbox": [0.237, 0.354, 0.702, 0.629]}]}
+ {"id": "000000415393", "image": "000000415393.jpg", "captions": ["a man on a skate board looks like he is falling", "A man does a skateboard trick on a skateboard ramp", "Guy falling off a skateboard in a room.", "A man riding a skateboard on top of a table.", "a man skating on part of a ramp with his skateboard"], "instances": [{"category": "person", "bbox": [0.361, 0.016, 0.809, 0.888]}, {"category": "skateboard", "bbox": [0.606, 0.809, 0.889, 0.901]}, {"category": "person", "bbox": [0.479, 0.091, 0.576, 0.386]}, {"category": "person", "bbox": [0.047, 0.441, 0.197, 0.759]}, {"category": "person", "bbox": [0.038, 0.453, 0.076, 0.545]}, {"category": "person", "bbox": [0.249, 0.307, 0.311, 0.591]}]}
+ {"id": "000000161011", "image": "000000161011.jpg", "captions": ["Three skiers posing for a picture on the slope.", "Three skiers pause for a photo at the top of a mountain.", "Three people standing on a mountain taking a picture as they ski.", "A woman and two men on skis on a snowy hillside surrounded by trees", "Three skiers have stopped to pose for a picture."], "instances": [{"category": "person", "bbox": [0.36, 0.321, 0.509, 0.82]}, {"category": "person", "bbox": [0.179, 0.281, 0.349, 0.795]}, {"category": "person", "bbox": [0.611, 0.292, 0.751, 0.809]}, {"category": "skis", "bbox": [0.595, 0.743, 0.732, 0.961]}, {"category": "skis", "bbox": [0.341, 0.724, 0.621, 0.907]}, {"category": "skis", "bbox": [0.212, 0.705, 0.398, 0.905]}]}
+ {"id": "000000284296", "image": "000000284296.jpg", "captions": ["Three giraffe's leaning over to get a sip of water.", "an image of a herd of giraffes in the water", "three giraffes banding down to drink water with trees in the background", "Three giraffe drinking from a pond with brush in back.", "Giraffes leaning down to drink at a watering hole"], "instances": [{"category": "giraffe", "bbox": [0.624, 0.387, 0.822, 0.635]}, {"category": "giraffe", "bbox": [0.4, 0.326, 0.561, 0.58]}, {"category": "giraffe", "bbox": [0.152, 0.291, 0.343, 0.551]}]}
+ {"id": "000000056013", "image": "000000056013.jpg", "captions": ["a number of luggage bags on a cart in a lobby", "Wheeled cart with luggage at lobby of commercial business.", "Trolley used for transporting personal luggage to guests rooms.", "A luggage cart topped with lots of luggage.", "a cart filled with suitcases and bags"], "instances": [{"category": "backpack", "bbox": [0.276, 0.52, 0.456, 0.678]}, {"category": "suitcase", "bbox": [0.41, 0.58, 0.597, 0.827]}, {"category": "suitcase", "bbox": [0.173, 0.645, 0.363, 0.836]}, {"category": "person", "bbox": [0.959, 0.297, 1.0, 0.478]}, {"category": "suitcase", "bbox": [0.526, 0.519, 0.712, 0.706]}, {"category": "person", "bbox": [0.762, 0.253, 0.871, 0.46]}, {"category": "backpack", "bbox": [0.517, 0.514, 0.694, 0.698]}, {"category": "handbag", "bbox": [0.316, 0.181, 0.431, 0.426]}, {"category": "suitcase", "bbox": [0.747, 0.453, 0.858, 0.557]}]}
+ {"id": "000000293505", "image": "000000293505.jpg", "captions": ["A person on a motor bike next to a cow.", "A woman riding a motorcycle down a dirt road.", "there is a woman riding a scooter down a dirt road", "A woman on a moped, two men and animals walking down the road.", "A woman on a motorcycle is next to a man walking a dog along with other people going down a dirt road."], "instances": [{"category": "cow", "bbox": [0.602, 0.472, 0.721, 0.816]}, {"category": "motorcycle", "bbox": [0.402, 0.512, 0.516, 0.788]}, {"category": "person", "bbox": [0.408, 0.4, 0.514, 0.639]}, {"category": "person", "bbox": [0.754, 0.301, 1.0, 1.0]}, {"category": "person", "bbox": [0.705, 0.415, 0.789, 0.714]}, {"category": "cow", "bbox": [0.347, 0.44, 0.373, 0.509]}, {"category": "cow", "bbox": [0.361, 0.436, 0.381, 0.501]}]}
+ {"id": "000000305873", "image": "000000305873.jpg", "captions": ["A little girl holding a red black dotted umbrella.", "A little girl with rain boots and a rain jacket on and an open umbrella to match her jacket.", "a little girl holding onto a lady bug pattern umbrella", "The child wears a labybug rain coat with a matching umbrella.", "A little girl wearing a ladybug raincoat and green rubber boots holding a ladybug umbrella"], "instances": [{"category": "umbrella", "bbox": [0.246, 0.002, 0.992, 0.415]}, {"category": "person", "bbox": [0.35, 0.132, 0.699, 0.791]}, {"category": "car", "bbox": [0.614, 0.0, 1.0, 0.465]}]}
+ {"id": "000000034096", "image": "000000034096.jpg", "captions": ["A house being built with lots of wood.", "A big pile of building material is placed on the floor in the wooden structure.", "A partially-built house with wooden studs and staircase in view.", "A house full of wood getting built at the moment.", "The beginning stages of a home still being made."], "instances": [{"category": "bed", "bbox": [0.505, 0.42, 0.721, 0.59]}, {"category": "tv", "bbox": [0.192, 0.441, 0.335, 0.606]}]}
+ {"id": "000000165257", "image": "000000165257.jpg", "captions": ["A large black counter top sitting next to a sink.", "a clean kitchen counter with a clean sink", "A kitchen with a sink, dishwasher and some boxes on the counter.", "A kitchen with a sink, dishwasher and boxes on the counter.", "a black counter on a wood cabinet in a kitchen", "a new kitchen cabinet with a sink being installed"], "instances": [{"category": "sink", "bbox": [0.513, 0.243, 0.718, 0.314]}]}
+ {"id": "000000431026", "image": "000000431026.jpg", "captions": ["a street sign on a city street near some tall bushes", "street signs on a metal pole lining a sidewalk lined with shrubbery.", "a large hedge of bushes on a corner near a street sign.", "Two street signs on sidewalk next to bushes and trees.", "Street signs along a well manicured street with large houses."], "instances": []}
+ {"id": "000000524575", "image": "000000524575.jpg", "captions": ["Three giraffe and a wildebeest in a field.", "A moose and several giraffes are grazing in the field.", "Zebras in the wild with a wildebeest behind them", "Two giraffe and a ox standing in a field eating grass.", "Giraffes and other safari animals graze in a sunlit field."], "instances": [{"category": "cow", "bbox": [0.46, 0.716, 0.643, 0.999]}, {"category": "giraffe", "bbox": [0.285, 0.5, 0.401, 0.826]}, {"category": "giraffe", "bbox": [0.083, 0.554, 0.179, 0.821]}, {"category": "giraffe", "bbox": [0.887, 0.481, 0.968, 0.715]}]}
+ {"id": "000000326550", "image": "000000326550.jpg", "captions": ["Black and white photograph of a person holding a surfboard by water.", "A person with a surfboard standing next to the water.", "A surfer stands on the rocks watching a wave crash.", "A man standing on a beach holding a surfboard.", "a person looking at the waves ready to surf"], "instances": [{"category": "person", "bbox": [0.327, 0.461, 0.492, 0.897]}, {"category": "surfboard", "bbox": [0.282, 0.56, 0.606, 0.741]}, {"category": "person", "bbox": [0.924, 0.352, 0.933, 0.362]}, {"category": "person", "bbox": [0.912, 0.348, 0.919, 0.36]}]}
+ {"id": "000000018476", "image": "000000018476.jpg", "captions": ["A tie that is sitting on top of a shirt.", "This photograph appears to be looking truly wonderful.", "a uniform complete with shoes laying on a bed", "Suit laid out with a red tie, white shirt and black shoes.", "a white shirt a red tie and some black shoes"], "instances": [{"category": "tie", "bbox": [0.457, 0.09, 0.853, 0.984]}, {"category": "bed", "bbox": [0.005, 0.005, 1.0, 0.379]}]}
+ {"id": "000000480652", "image": "000000480652.jpg", "captions": ["These suitcases are sitting next to a chair.", "An assortment of luggage bags stacked by a kitchen chair.", "A stack of luggage by a chair and table.", "a table and chair with several pieces of luggage nearby", "A pile of luggage sitting on the floor."], "instances": [{"category": "chair", "bbox": [0.483, 0.192, 1.0, 0.769]}, {"category": "backpack", "bbox": [0.433, 0.429, 0.742, 0.856]}, {"category": "suitcase", "bbox": [0.059, 0.414, 0.453, 0.841]}, {"category": "handbag", "bbox": [0.19, 0.184, 0.779, 0.475]}, {"category": "suitcase", "bbox": [0.175, 0.204, 0.583, 0.462]}]}
+ {"id": "000000012748", "image": "000000012748.jpg", "captions": ["A man and child next to a horse.", "a little boy touching the nose of a brown horse", "A man holding a baby whose petting a horse.", "a man letting his baby pet a horse", "man holding a baby and petting a horse"], "instances": [{"category": "horse", "bbox": [0.003, 0.079, 0.504, 0.868]}, {"category": "person", "bbox": [0.452, 0.294, 1.0, 0.989]}, {"category": "person", "bbox": [0.46, 0.217, 1.0, 0.988]}]}
+ {"id": "000000247840", "image": "000000247840.jpg", "captions": ["Large group of people standing outside a restaurant together.", "A dairy queen has people standing outside waiting", "an image of people standing outside and ice cream store", "Several people are lined up outside of a store.", "The front of a Dairy Queen restaurant with people entering the side."], "instances": [{"category": "fire hydrant", "bbox": [0.774, 0.674, 0.83, 0.807]}, {"category": "person", "bbox": [0.741, 0.465, 0.824, 0.755]}, {"category": "person", "bbox": [0.806, 0.471, 0.839, 0.722]}, {"category": "person", "bbox": [0.831, 0.499, 0.866, 0.726]}, {"category": "bench", "bbox": [0.061, 0.69, 0.219, 0.768]}, {"category": "handbag", "bbox": [0.859, 0.558, 0.877, 0.603]}, {"category": "person", "bbox": [0.719, 0.504, 0.75, 0.626]}, {"category": "potted plant", "bbox": [0.7, 0.648, 0.764, 0.743]}, {"category": "handbag", "bbox": [0.827, 0.548, 0.837, 0.577]}, {"category": "sandwich", "bbox": [0.359, 0.618, 0.417, 0.694]}]}
+ {"id": "000000399452", "image": "000000399452.jpg", "captions": ["a sandwhich sitting on a plate next to a glass of tea, bowl of soup", "a sandwich on a white plate a drink on a brown table", "A sandwich and chips sit on a white plate.", "a large plate of food with a glass of soda by it", "A sandwich sitting on top of a white plate next to a cup of coffee."], "instances": [{"category": "sandwich", "bbox": [0.175, 0.326, 0.605, 0.71]}, {"category": "cup", "bbox": [0.504, 0.024, 0.687, 0.419]}, {"category": "knife", "bbox": [0.742, 0.283, 0.857, 0.376]}, {"category": "spoon", "bbox": [0.618, 0.46, 0.797, 0.809]}, {"category": "fork", "bbox": [0.684, 0.254, 0.805, 0.395]}, {"category": "bowl", "bbox": [0.782, 0.366, 1.0, 0.62]}, {"category": "chair", "bbox": [0.202, 0.0, 0.671, 0.148]}, {"category": "dining table", "bbox": [0.002, 0.126, 0.996, 0.987]}]}
+ {"id": "000000515716", "image": "000000515716.jpg", "captions": ["A couple of women standing on either side of a man wearing glasses.", "Two women and a man are holding glasses up at a wine tasting.", "Three young adults holding wine glasses while standing at a bar.", "A group of people sit holding glasses and smiling at a table with several bottles.", "A group of people at a celebration having a taste of wine."], "instances": [{"category": "bottle", "bbox": [0.529, 0.604, 0.637, 0.908]}, {"category": "bottle", "bbox": [0.379, 0.398, 0.481, 0.892]}, {"category": "bottle", "bbox": [0.942, 0.464, 0.988, 0.653]}, {"category": "person", "bbox": [0.0, 0.126, 0.136, 0.811]}, {"category": "person", "bbox": [0.05, 0.093, 0.211, 0.471]}, {"category": "person", "bbox": [0.401, 0.031, 0.678, 0.683]}, {"category": "person", "bbox": [0.617, 0.191, 0.94, 0.858]}, {"category": "person", "bbox": [0.723, 0.098, 0.947, 0.564]}, {"category": "wine glass", "bbox": [0.634, 0.434, 0.697, 0.628]}, {"category": "wine glass", "bbox": [0.285, 0.346, 0.372, 0.558]}, {"category": "wine glass", "bbox": [0.522, 0.422, 0.583, 0.544]}, {"category": "handbag", "bbox": [0.704, 0.601, 1.0, 0.916]}, {"category": "person", "bbox": [0.944, 0.319, 0.999, 0.604]}, {"category": "bottle", "bbox": [0.921, 0.46, 0.953, 0.636]}, {"category": "person", "bbox": [0.116, 0.171, 0.41, 0.829]}]}
+ {"id": "000000116173", "image": "000000116173.jpg", "captions": ["The boy is on his surfboard in the water riding it.", "a young boy riding a boogie board in the water", "A boy riding surf board in the ocean.", "A young boy is riding a surfboard on a small wave.", "A young boy is surfing in the ocean."], "instances": [{"category": "person", "bbox": [0.485, 0.238, 0.702, 0.821]}, {"category": "person", "bbox": [0.866, 0.223, 0.921, 0.29]}, {"category": "person", "bbox": [0.752, 0.146, 0.775, 0.188]}, {"category": "surfboard", "bbox": [0.239, 0.758, 0.782, 0.846]}, {"category": "surfboard", "bbox": [0.853, 0.277, 0.981, 0.29]}, {"category": "surfboard", "bbox": [0.727, 0.169, 0.801, 0.198]}, {"category": "person", "bbox": [0.637, 0.194, 0.677, 0.261]}]}
+ {"id": "000000186013", "image": "000000186013.jpg", "captions": ["A beach scene includes many different kites flying in a cloudy sky.", "Kites being flown at the beach at twilight.", "A beach with flags in the ground and kites overhead in the sky.", "A beach with rows of flags in the sand and kites flying overhead.", "A beach filled with kites and wind sails next to the ocean."], "instances": [{"category": "kite", "bbox": [0.174, 0.4, 0.351, 0.483]}, {"category": "kite", "bbox": [0.144, 0.13, 0.273, 0.17]}, {"category": "kite", "bbox": [0.236, 0.269, 0.268, 0.294]}, {"category": "kite", "bbox": [0.464, 0.204, 0.598, 0.271]}, {"category": "kite", "bbox": [0.61, 0.304, 0.659, 0.342]}, {"category": "kite", "bbox": [0.545, 0.435, 0.565, 0.452]}, {"category": "kite", "bbox": [0.027, 0.558, 0.151, 0.59]}, {"category": "kite", "bbox": [0.93, 0.429, 0.973, 0.536]}, {"category": "kite", "bbox": [0.684, 0.36, 0.697, 0.374]}, {"category": "surfboard", "bbox": [0.393, 0.627, 0.446, 0.934]}, {"category": "person", "bbox": [0.959, 0.685, 0.984, 0.713]}, {"category": "person", "bbox": [0.919, 0.681, 0.94, 0.725]}, {"category": "person", "bbox": [0.8, 0.597, 0.805, 0.61]}, {"category": "person", "bbox": [0.079, 0.928, 0.116, 0.975]}, {"category": "kite", "bbox": [0.743, 0.307, 0.755, 0.319]}, {"category": "kite", "bbox": [0.78, 0.322, 0.795, 0.335]}, {"category": "kite", "bbox": [0.536, 0.526, 0.597, 0.617]}, {"category": "person", "bbox": [0.941, 0.694, 0.961, 0.726]}, {"category": "kite", "bbox": [0.575, 0.446, 0.594, 0.471]}]}
+ {"id": "000000015029", "image": "000000015029.jpg", "captions": ["A man holding a white frisbee standing on top of a field.", "A man is playing frisbee next to a tent.", "Guy at the park holding a frisbee with people in the back under a tent", "A man is holding a Frisbee standing in the grass.", "Young adult male holding a frisbee at an event."], "instances": [{"category": "frisbee", "bbox": [0.138, 0.359, 0.215, 0.587]}, {"category": "person", "bbox": [0.16, 0.002, 0.726, 0.995]}, {"category": "person", "bbox": [0.81, 0.73, 0.852, 0.825]}, {"category": "person", "bbox": [0.786, 0.749, 0.833, 0.814]}, {"category": "person", "bbox": [0.847, 0.743, 0.89, 0.804]}, {"category": "person", "bbox": [0.614, 0.749, 0.706, 0.936]}]}
+ {"id": "000000500565", "image": "000000500565.jpg", "captions": ["A woman holding a child wrapped in a towel brushing her teeth.", "A woman is holding a baby who is wrapped in a towel and holding a toothbrush", "A woman holding a little boy who is brushing his teeth.", "A baby with a toothbrush in his mouth while being held by a woman", "a close up of an adult holding a child brushing their teeth"], "instances": [{"category": "toothbrush", "bbox": [0.586, 0.66, 0.754, 0.821]}, {"category": "person", "bbox": [0.002, 0.007, 0.637, 0.991]}, {"category": "person", "bbox": [0.357, 0.196, 0.998, 0.984]}]}
+ {"id": "000000297323", "image": "000000297323.jpg", "captions": ["Two buses are parked against a curb in front of a building.", "Two automobiles parked on the side of a building.", "two tourist buses parked on street in front of old industrial building", "Two unique city buses stopped at a stop sign.", "Buses parked outside by a building and stop sign."], "instances": [{"category": "bus", "bbox": [0.7, 0.711, 0.92, 0.881]}, {"category": "person", "bbox": [0.936, 0.771, 0.972, 0.833]}, {"category": "stop sign", "bbox": [0.237, 0.666, 0.285, 0.728]}, {"category": "bus", "bbox": [0.334, 0.71, 0.678, 0.935]}, {"category": "truck", "bbox": [0.335, 0.72, 0.683, 0.934]}, {"category": "person", "bbox": [0.34, 0.791, 0.367, 0.834]}]}
+ {"id": "000000441147", "image": "000000441147.jpg", "captions": ["Two antique suitcases sit stacked one on top of the other.", "Two suitcases are stacked on each other and one is black while the other is brown and yellow.", "a close up of two luggage suit cases stacked on each other", "A stack of antique luggage is displayed with price tags.", "two suitcases made of leather and stacked on top of each other"], "instances": [{"category": "suitcase", "bbox": [0.167, 0.025, 0.989, 0.445]}, {"category": "suitcase", "bbox": [0.002, 0.31, 0.994, 0.996]}]}
+ {"id": "000000353536", "image": "000000353536.jpg", "captions": ["A table topped with plates and glasses with eating utensils..", "a fork is laying on a small white plate", "dirty dishes on a table, and a bottle of something.", "a table top with some dishes on top of it", "A table full of dirty dishes is pictured in this image."], "instances": [{"category": "dining table", "bbox": [0.0, 0.007, 0.998, 0.988]}, {"category": "bottle", "bbox": [0.554, 0.002, 0.768, 0.411]}, {"category": "cup", "bbox": [0.372, 0.011, 0.544, 0.427]}, {"category": "fork", "bbox": [0.442, 0.464, 0.818, 0.572]}, {"category": "fork", "bbox": [0.089, 0.233, 0.272, 0.456]}, {"category": "spoon", "bbox": [0.144, 0.218, 0.326, 0.413]}, {"category": "cup", "bbox": [0.688, 0.056, 0.812, 0.361]}]}
+ {"id": "000000416256", "image": "000000416256.jpg", "captions": ["A cat laying on the floor next to a keyboard.", "an orange and white cat is laying next to a keyboard and some wires", "A cat is laying next to a computer keyboard.", "a cat laying on a floor next to a keyboard", "A CAT LAYING ON THE FLOOR AMIDST A COMPUTER,SPEAKERS,CORDS"], "instances": [{"category": "cat", "bbox": [0.235, 0.23, 0.737, 0.639]}, {"category": "keyboard", "bbox": [0.243, 0.562, 0.631, 0.836]}, {"category": "keyboard", "bbox": [0.058, 0.33, 0.277, 0.608]}]}
+ {"id": "000000214367", "image": "000000214367.jpg", "captions": ["Wood shading on the side of a window with brick siding.", "A tree filled with lots of red fruit near a building.", "By the window outside is a apple tree, where the apples are ready to be picked.", "Some very nice looking red fruity by a window,", "A shuttered window has a fruit tree outside it."], "instances": [{"category": "apple", "bbox": [0.214, 0.112, 0.408, 0.266]}, {"category": "apple", "bbox": [0.472, 0.166, 0.618, 0.293]}, {"category": "apple", "bbox": [0.055, 0.592, 0.172, 0.686]}, {"category": "apple", "bbox": [0.126, 0.661, 0.236, 0.739]}, {"category": "apple", "bbox": [0.52, 0.09, 0.609, 0.143]}, {"category": "apple", "bbox": [0.226, 0.354, 0.285, 0.409]}, {"category": "apple", "bbox": [0.0, 0.698, 0.096, 0.771]}, {"category": "apple", "bbox": [0.001, 0.646, 0.042, 0.713]}, {"category": "apple", "bbox": [0.258, 0.719, 0.329, 0.778]}]}
+ {"id": "000000210299", "image": "000000210299.jpg", "captions": ["A little boy riding his bike and wearing a helmet", "A little boy raveling down a road on a bike, with a yellow helmet on.", "The boy wears a helmet while riding his bicycle.", "a small child wearing a helmet and riding a bike", "A little boy wearing a helmet and riding a bike."], "instances": [{"category": "person", "bbox": [0.198, 0.259, 0.399, 0.679]}, {"category": "bicycle", "bbox": [0.213, 0.383, 0.408, 0.835]}]}
+ {"id": "000000088218", "image": "000000088218.jpg", "captions": ["Signs proclaim the famous Haight Ashbury intersection and district.", "a pole with street lights, signs and wires attached to it", "A traffic light at the intersection of Haight and Ashbury", "A traffic sign is shown with traffic signs above it.", "The street signs and traffic signal are below wires attached to the pole."], "instances": [{"category": "traffic light", "bbox": [0.443, 0.435, 0.658, 0.721]}]}
+ {"id": "000000020650", "image": "000000020650.jpg", "captions": ["Burger with broccoli, pickle, and fork on orange plate", "On a plate is kept a burger and a bowl of broccoli and a fork.", "There is half a sandwich on an orange plate with a pickle and a bowl of broccoli", "A A bowl and a sandwich on an orange plate on a table.", "A plate has a sandwich, broccoli, and a pickle."], "instances": [{"category": "sandwich", "bbox": [0.436, 0.155, 0.805, 0.859]}, {"category": "sandwich", "bbox": [0.311, 0.006, 0.748, 0.293]}, {"category": "fork", "bbox": [0.0, 0.665, 0.578, 0.876]}, {"category": "bowl", "bbox": [0.002, 0.263, 0.487, 0.744]}, {"category": "bowl", "bbox": [0.708, 0.003, 0.828, 0.03]}, {"category": "broccoli", "bbox": [0.185, 0.288, 0.366, 0.546]}, {"category": "broccoli", "bbox": [0.017, 0.344, 0.384, 0.654]}, {"category": "broccoli", "bbox": [0.31, 0.191, 0.466, 0.463]}, {"category": "broccoli", "bbox": [0.104, 0.107, 0.285, 0.342]}, {"category": "broccoli", "bbox": [0.092, 0.276, 0.242, 0.442]}, {"category": "dining table", "bbox": [0.002, 0.0, 0.999, 0.987]}]}
+ {"id": "000000514915", "image": "000000514915.jpg", "captions": ["A large black dog laying on a kitchen floor.", "A dog is laying down on the floor in the home.", "Black dog laying down on the kitchen floor next to it's bowls and toy", "A black dog with a red collar laying on a tiled floor.", "A black dog that is laying on the floor."], "instances": [{"category": "dog", "bbox": [0.087, 0.276, 0.812, 0.792]}, {"category": "bowl", "bbox": [0.437, 0.09, 0.533, 0.213]}, {"category": "bowl", "bbox": [0.537, 0.035, 0.665, 0.141]}]}
+ {"id": "000000205183", "image": "000000205183.jpg", "captions": ["A duck walking along a paved road next to a patch of grass.", "A close up of a duck walking on a path.", "a duck walks along a cement patch while looking down", "A white duck out of water, walking on the ground.", "A goose standing in the road, looking at the ground."], "instances": [{"category": "bird", "bbox": [0.291, 0.235, 0.859, 0.889]}]}
+ {"id": "000000534270", "image": "000000534270.jpg", "captions": ["Man and woman with umbrella hats sitting on top of a bridge.", "A couple equipped with umbrella hats taking a break from walking their dog on a bridge on a rainy day.", "Two people in ridiculous looking umbrella hats.", "two people with umbrella hats near one another", "A couple of people wearing umbrella hats next to the ocean."], "instances": [{"category": "dog", "bbox": [0.456, 0.832, 0.6, 0.983]}, {"category": "person", "bbox": [0.433, 0.464, 0.636, 0.975]}, {"category": "person", "bbox": [0.263, 0.321, 0.459, 0.978]}, {"category": "boat", "bbox": [0.912, 0.4, 0.978, 0.433]}, {"category": "boat", "bbox": [0.211, 0.236, 0.478, 0.304]}, {"category": "boat", "bbox": [0.144, 0.328, 0.189, 0.361]}, {"category": "umbrella", "bbox": [0.443, 0.402, 0.607, 0.473]}, {"category": "umbrella", "bbox": [0.325, 0.311, 0.483, 0.432]}, {"category": "umbrella", "bbox": [0.207, 0.738, 0.284, 0.778]}, {"category": "umbrella", "bbox": [0.489, 0.713, 0.649, 0.83]}]}
+ {"id": "000000408439", "image": "000000408439.jpg", "captions": ["Cliffs rise on the edge of a placid lake.", "A scenic view of a river with a train on the edge of it in the distance.", "A large lake surrounded by beautiful tree covered mountains.", "a landscape scene with water, mountains and trees", "A train on a waterfront track surrounded by mountains."], "instances": [{"category": "train", "bbox": [0.008, 0.591, 0.562, 0.644]}]}
+ {"id": "000000474253", "image": "000000474253.jpg", "captions": ["A man riding on the back of a horse through a river.", "A person is riding a horse through water.", "Horse and rider crossing waterway during competitive event.", "A woman riding a horse splashes through a large puddle.", "A young man riding a horse through some water."], "instances": [{"category": "horse", "bbox": [0.385, 0.235, 0.651, 0.814]}, {"category": "person", "bbox": [0.396, 0.06, 0.576, 0.675]}, {"category": "person", "bbox": [0.29, 0.148, 0.355, 0.333]}, {"category": "person", "bbox": [0.129, 0.163, 0.212, 0.349]}, {"category": "person", "bbox": [0.005, 0.014, 0.038, 0.165]}, {"category": "person", "bbox": [0.144, 0.011, 0.193, 0.155]}, {"category": "person", "bbox": [0.089, 0.007, 0.133, 0.162]}]}
+ {"id": "000000098029", "image": "000000098029.jpg", "captions": ["a table with many plates on it with a bread basket", "A table set for four has many foods and fruits on it.", "Several objects displayed on a kitchen table including bread, oranges and plating.", "Several dishes and food items sit on a table.", "An assortment of foods sitting on a round brown table."], "instances": [{"category": "refrigerator", "bbox": [0.013, 0.004, 0.37, 0.317]}, {"category": "bottle", "bbox": [0.467, 0.517, 0.555, 0.638]}, {"category": "bottle", "bbox": [0.602, 0.536, 0.658, 0.609]}, {"category": "chair", "bbox": [0.747, 0.367, 1.0, 0.592]}, {"category": "chair", "bbox": [0.044, 0.368, 0.358, 0.544]}, {"category": "cup", "bbox": [0.296, 0.465, 0.359, 0.54]}, {"category": "cup", "bbox": [0.709, 0.67, 0.782, 0.736]}, {"category": "cup", "bbox": [0.213, 0.684, 0.294, 0.753]}, {"category": "knife", "bbox": [0.787, 0.699, 0.922, 0.797]}, {"category": "knife", "bbox": [0.161, 0.539, 0.265, 0.584]}, {"category": "spoon", "bbox": [0.813, 0.674, 0.922, 0.759]}, {"category": "spoon", "bbox": [0.156, 0.555, 0.233, 0.587]}, {"category": "spoon", "bbox": [0.596, 0.467, 0.613, 0.509]}, {"category": "bowl", "bbox": [0.241, 0.753, 0.505, 0.935]}, {"category": "banana", "bbox": [0.632, 0.138, 0.718, 0.161]}, {"category": "apple", "bbox": [0.701, 0.152, 0.758, 0.191]}, {"category": "orange", "bbox": [0.607, 0.66, 0.692, 0.716]}, {"category": "orange", "bbox": [0.565, 0.636, 0.611, 0.667]}, {"category": "orange", "bbox": [0.526, 0.624, 0.572, 0.652]}, {"category": "orange", "bbox": [0.61, 0.628, 0.656, 0.657]}, {"category": "orange", "bbox": [0.599, 0.649, 0.643, 0.677]}, {"category": "dining table", "bbox": [0.013, 0.439, 0.964, 0.986]}, {"category": "cup", "bbox": [0.612, 0.489, 0.669, 0.548]}, {"category": "knife", "bbox": [0.605, 0.457, 0.638, 0.53]}, {"category": "apple", "bbox": [0.502, 0.137, 0.537, 0.159]}, {"category": "orange", "bbox": [0.54, 0.135, 0.563, 0.151]}, {"category": "orange", "bbox": [0.527, 0.129, 0.554, 0.142]}, {"category": "orange", "bbox": [0.611, 0.155, 0.641, 0.171]}, {"category": "chair", "bbox": [0.0, 0.843, 0.29, 0.989]}, {"category": "cup", "bbox": [0.353, 0.469, 0.411, 0.511]}, {"category": "cup", "bbox": [0.609, 0.716, 0.682, 0.786]}, {"category": "orange", "bbox": [0.638, 0.158, 0.679, 0.177]}, {"category": "cake", "bbox": [0.38, 0.821, 0.481, 0.895]}, {"category": "chair", "bbox": [0.79, 0.747, 1.0, 1.0]}, {"category": "bottle", "bbox": [0.719, 0.55, 0.769, 0.616]}, {"category": "bottle", "bbox": [0.795, 0.546, 0.873, 0.613]}, {"category": "knife", "bbox": [0.17, 0.799, 0.264, 0.88]}, {"category": "cup", "bbox": [0.317, 0.695, 0.391, 0.752]}]}
71
+ {"id": "000000294073", "image": "000000294073.jpg", "captions": ["A woman and a man standing between two brown horses.", "A COUPLE WEARING YELLOW DRESS STANDING NEAR TWO HORSES.", "An older couple stands between two horses.", "A man and a woman standing with two horses", "A man and a woman stand in between two horses."], "instances": [{"category": "horse", "bbox": [0.0, 0.052, 0.49, 0.989]}, {"category": "horse", "bbox": [0.632, 0.23, 1.0, 0.989]}, {"category": "person", "bbox": [0.425, 0.326, 0.696, 0.987]}, {"category": "person", "bbox": [0.627, 0.203, 0.828, 0.986]}, {"category": "book", "bbox": [0.525, 0.597, 0.644, 0.833]}]}
72
+ {"id": "000000203629", "image": "000000203629.jpg", "captions": ["A man on a cell phone in a public area holding his thumb up.", "A group of people gathered inside of a room.", "A man on his cellphone posing for a picture.", "A man giving a thumbs up while on a cell phone.", "The man is giving a thumbs up while on his phone."], "instances": [{"category": "cell phone", "bbox": [0.43, 0.459, 0.449, 0.503]}, {"category": "cup", "bbox": [0.756, 0.838, 0.865, 0.98]}, {"category": "person", "bbox": [0.232, 0.317, 0.603, 0.98]}, {"category": "person", "bbox": [0.602, 0.405, 1.0, 0.999]}, {"category": "person", "bbox": [0.003, 0.339, 0.313, 0.987]}, {"category": "person", "bbox": [0.164, 0.379, 0.258, 0.733]}, {"category": "person", "bbox": [0.564, 0.36, 0.673, 0.645]}, {"category": "person", "bbox": [0.241, 0.379, 0.336, 0.512]}, {"category": "person", "bbox": [0.682, 0.372, 0.736, 0.502]}, {"category": "person", "bbox": [0.654, 0.428, 0.734, 0.536]}, {"category": "person", "bbox": [0.718, 0.368, 0.787, 0.508]}, {"category": "person", "bbox": [0.148, 0.362, 0.205, 0.529]}, {"category": "person", "bbox": [0.001, 0.431, 0.044, 0.564]}, {"category": "cup", "bbox": [0.901, 0.808, 0.995, 0.982]}]}
73
+ {"id": "000000119876", "image": "000000119876.jpg", "captions": ["A man dressed loudly is using his cell phone.", "A man talking on the phone while he walks down the street.", "A man with pink hair talking on a cell phone.", "A man in a purple shirt and tie and purple hair.", "a man colored his hair in purple walking on the road"], "instances": [{"category": "bicycle", "bbox": [0.525, 0.222, 0.924, 0.608]}, {"category": "bicycle", "bbox": [0.895, 0.249, 1.0, 0.642]}, {"category": "person", "bbox": [0.0, 0.0, 0.738, 1.0]}, {"category": "tie", "bbox": [0.319, 0.255, 0.423, 0.638]}, {"category": "cell phone", "bbox": [0.411, 0.13, 0.426, 0.161]}, {"category": "handbag", "bbox": [0.369, 0.205, 0.575, 0.839]}]}
74
+ {"id": "000000164255", "image": "000000164255.jpg", "captions": ["An umbrella that is standing in the sand.", "An umbrella is stuck in the sand on the beach.", "a colorful striped umbrella on the beach near the ocean", "A colorful umbrella is set up at the beach.", "The colorful umbrella is sitting by the beach,"], "instances": [{"category": "umbrella", "bbox": [0.0, 0.101, 0.567, 0.575]}]}
75
+ {"id": "000000192817", "image": "000000192817.jpg", "captions": ["A view from a window high up in the sky.", "A bunch of mountains seen from a plane window.", "The window from a plane overlooking the ground.", "The view of a mountain area from an airplane window.", "An aerial view of mountains and lakes from an airplane window."], "instances": []}
76
+ {"id": "000000258285", "image": "000000258285.jpg", "captions": ["Two large passenger jets flying over a beach filled with birds.", "A plane is flying over a bird filed lake", "Two airplanes are in the sky over blue water.", "An airplane landing over an airplane on the ground.", "A photo of two plans with water and birds surrounding it , one plane in the air one one the ground."], "instances": [{"category": "bird", "bbox": [0.507, 0.941, 0.536, 0.973]}, {"category": "bird", "bbox": [0.304, 0.933, 0.315, 0.95]}, {"category": "bird", "bbox": [0.129, 0.885, 0.143, 0.912]}, {"category": "bird", "bbox": [0.158, 0.851, 0.165, 0.87]}, {"category": "bird", "bbox": [0.404, 0.839, 0.429, 0.864]}, {"category": "bird", "bbox": [0.498, 0.833, 0.513, 0.861]}, {"category": "airplane", "bbox": [0.276, 0.085, 0.825, 0.316]}, {"category": "airplane", "bbox": [0.478, 0.252, 0.983, 0.495]}, {"category": "bird", "bbox": [0.552, 0.828, 0.564, 0.844]}, {"category": "bird", "bbox": [0.789, 0.812, 0.798, 0.836]}, {"category": "bird", "bbox": [0.927, 0.82, 0.936, 0.838]}, {"category": "bird", "bbox": [0.65, 0.828, 0.664, 0.849]}, {"category": "bird", "bbox": [0.752, 0.81, 0.763, 0.83]}, {"category": "bird", "bbox": [0.841, 0.817, 0.852, 0.828]}, {"category": "bird", "bbox": [0.292, 0.849, 0.311, 0.868]}, {"category": "bird", "bbox": [0.005, 0.727, 0.981, 0.998]}]}
77
+ {"id": "000000506483", "image": "000000506483.jpg", "captions": ["An art installation is placed by a street.", "People sit near a display of large artworks including an oversize bench and painted feline heads.", "Looking down on a giant rocking bench and large animal heads.", "An over sized wooden bench next to two massive animal art sculptures.", "artistic sculptures and images on a city street"], "instances": [{"category": "car", "bbox": [0.656, 0.939, 0.933, 1.0]}, {"category": "person", "bbox": [0.08, 0.664, 0.147, 0.805]}, {"category": "person", "bbox": [0.154, 0.646, 0.217, 0.821]}, {"category": "bench", "bbox": [0.316, 0.124, 0.951, 0.635]}, {"category": "backpack", "bbox": [0.062, 0.701, 0.097, 0.769]}, {"category": "person", "bbox": [0.0, 0.132, 0.031, 0.197]}]}
78
+ {"id": "000000502168", "image": "000000502168.jpg", "captions": ["a fleet of naval ships in the ocean", "A group of men on aircraft carrier with other boats in the distance.", "A large ship floating in the ocean next to other ships.", "Several men on a boat looking over the side.", "The men wear hardhats as they work on the aircraft carrier."], "instances": [{"category": "boat", "bbox": [0.634, 0.292, 1.0, 0.982]}, {"category": "person", "bbox": [0.675, 0.507, 0.736, 0.731]}, {"category": "person", "bbox": [0.684, 0.737, 0.817, 1.0]}, {"category": "person", "bbox": [0.803, 0.691, 0.883, 0.932]}, {"category": "person", "bbox": [0.741, 0.56, 0.798, 0.767]}, {"category": "person", "bbox": [0.924, 0.269, 0.951, 0.367]}, {"category": "boat", "bbox": [0.079, 0.171, 0.172, 0.231]}, {"category": "boat", "bbox": [0.863, 0.131, 0.961, 0.239]}, {"category": "boat", "bbox": [0.435, 0.288, 0.46, 0.313]}, {"category": "boat", "bbox": [0.591, 0.186, 0.605, 0.222]}, {"category": "person", "bbox": [0.451, 0.289, 0.455, 0.296]}, {"category": "person", "bbox": [0.446, 0.29, 0.451, 0.296]}, {"category": "person", "bbox": [0.872, 0.627, 0.957, 0.966]}, {"category": "person", "bbox": [0.44, 0.288, 0.446, 0.3]}]}
79
+ {"id": "000000319432", "image": "000000319432.jpg", "captions": ["Man holding two shirts with luggage and window", "A man holding clothes on a hanger with a suitcase in front of him.", "A man show a red and a white clothing hangers.", "A man holding his garment bags in both hands", "A man holding up some clothes in some hanger bags."], "instances": [{"category": "person", "bbox": [0.0, 0.092, 0.776, 0.852]}, {"category": "suitcase", "bbox": [0.153, 0.798, 0.587, 1.0]}]}
80
+ {"id": "000000131019", "image": "000000131019.jpg", "captions": ["Two zebras and two monkeys walking on the grass.", "Two giraffes and another animal are on green grass.", "A baboon and two zebras grazing on the savannah.", "A baboon and its baby eat by two zebras in the grass", "Monkey standing behind two zebras as they graze."], "instances": [{"category": "zebra", "bbox": [0.367, 0.258, 0.834, 0.646]}, {"category": "zebra", "bbox": [0.161, 0.13, 0.396, 0.375]}, {"category": "bird", "bbox": [0.309, 0.138, 0.34, 0.163]}]}
chatgpt.svg ADDED
cli.py ADDED
@@ -0,0 +1,134 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from tinyllava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
+ from tinyllava.conversation import conv_templates, SeparatorStyle
6
+ from tinyllava.model.builder import load_pretrained_model
7
+ from tinyllava.utils import disable_torch_init
8
+ from tinyllava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9
+
10
+ from PIL import Image
11
+
12
+ import requests
13
+ from PIL import Image
14
+ from io import BytesIO
15
+ from transformers import TextStreamer
16
+
17
+
18
+ def load_image(image_file):
19
+ if image_file.startswith('http://') or image_file.startswith('https://'):
20
+ response = requests.get(image_file)
21
+ image = Image.open(BytesIO(response.content)).convert('RGB')
22
+ else:
23
+ image = Image.open(image_file).convert('RGB')
24
+ return image
25
+
26
+
27
+ def main(args):
28
+ # Model
29
+ disable_torch_init()
30
+
31
+ model_name = get_model_name_from_path(args.model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33
+
34
+ if 'llama-2' in model_name.lower():
35
+ conv_mode = "llava_llama_2"
36
+ elif "v1" in model_name.lower():
37
+ conv_mode = "llava_v1"
38
+ elif "mpt" in model_name.lower():
39
+ conv_mode = "mpt"
40
+ elif 'phi' in model_name.lower() or '3.1b' in model_name.lower():
41
+ conv_mode = 'phi'
42
+ elif 'stablelm' in model_name.lower() or '2.0b' in model_name.lower():
43
+ conv_mode = 'phi'
44
+ elif 'tinyllama' in model_name.lower() or '1.5b' in model_name.lower():
45
+ conv_mode = 'v1'
46
+ elif 'qwen' in model_name.lower():
47
+ conv_mode = 'qwen'
48
+ else:
49
+ conv_mode = "llava_v0"
50
+
51
+
52
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
53
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
54
+ else:
55
+ args.conv_mode = conv_mode
56
+
57
+ conv = conv_templates[args.conv_mode].copy()
58
+ if "mpt" in model_name.lower():
59
+ roles = ('user', 'assistant')
60
+ else:
61
+ roles = conv.roles
62
+
63
+ image = load_image(args.image_file)
64
+ # Similar operation in model_worker.py
65
+ image_tensor = process_images([image], image_processor, model.config)
66
+ if type(image_tensor) is list:
67
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
68
+ else:
69
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
70
+
71
+ while True:
72
+ try:
73
+ inp = input(f"{roles[0]}: ")
74
+ except EOFError:
75
+ inp = ""
76
+ if not inp:
77
+ print("exit...")
78
+ break
79
+
80
+ print(f"{roles[1]}: ", end="")
81
+
82
+ if image is not None:
83
+ # first message
84
+ if model.config.mm_use_im_start_end:
85
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
86
+ else:
87
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
88
+ conv.append_message(conv.roles[0], inp)
89
+ image = None
90
+ else:
91
+ # later messages
92
+ conv.append_message(conv.roles[0], inp)
93
+ conv.append_message(conv.roles[1], None)
94
+ prompt = conv.get_prompt()
95
+
96
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
97
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
98
+ keywords = [stop_str]
99
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
100
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
101
+
102
+ with torch.inference_mode():
103
+ output_ids = model.generate(
104
+ input_ids,
105
+ images=image_tensor,
106
+ do_sample=True if args.temperature > 0 else False,
107
+ temperature=args.temperature,
108
+ max_new_tokens=args.max_new_tokens,
109
+ streamer=streamer,
110
+ use_cache=True,
111
+ # stopping_criteria=[stopping_criteria]
112
+ )
113
+
114
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
115
+ conv.messages[-1][-1] = outputs
116
+
117
+ if args.debug:
118
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
119
+
120
+
121
+ if __name__ == "__main__":
122
+ parser = argparse.ArgumentParser()
123
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
124
+ parser.add_argument("--model-base", type=str, default=None)
125
+ parser.add_argument("--image-file", type=str, required=True)
126
+ parser.add_argument("--device", type=str, default="cuda")
127
+ parser.add_argument("--conv-mode", type=str, default=None)
128
+ parser.add_argument("--temperature", type=float, default=0.2)
129
+ parser.add_argument("--max-new-tokens", type=int, default=512)
130
+ parser.add_argument("--load-8bit", action="store_true")
131
+ parser.add_argument("--load-4bit", action="store_true")
132
+ parser.add_argument("--debug", action="store_true")
133
+ args = parser.parse_args()
134
+ main(args)
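The conversation-mode selection in `main` boils down to a chain of substring checks on the model name. A minimal standalone sketch of that logic (the helper name and the example checkpoint names below are illustrative, not part of the repository):

```python
# Standalone sketch of the conversation-mode inference performed in main() above.
def infer_conv_mode(model_name: str) -> str:
    name = model_name.lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "v1" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    if "phi" in name or "3.1b" in name:
        return "phi"
    if "stablelm" in name or "2.0b" in name:
        return "phi"
    if "tinyllama" in name or "1.5b" in name:
        return "v1"
    if "qwen" in name:
        return "qwen"
    return "llava_v0"

# Illustrative checks with hypothetical checkpoint names.
assert infer_conv_mode("TinyLLaVA-3.1B") == "phi"
assert infer_conv_mode("TinyLLaVA-1.5B") == "v1"
```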
configuration_phi.py ADDED
@@ -0,0 +1,186 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi model configuration"""
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
27
+ }
28
+
29
+
30
+ class PhiConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate a Phi
33
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
34
+ defaults will yield a similar configuration to that of the Phi
35
+ [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 51200):
40
+ Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`PhiModel`].
42
+ hidden_size (`int`, *optional*, defaults to 2048):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 8192):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 24):
47
+ Number of hidden layers in the Transformer decoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer decoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
59
+ Dropout probability for mlp outputs.
60
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
61
+ The dropout ratio for the embeddings.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio after computing the attention scores.
64
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
65
+ The non-linear activation function (function or string) in the decoder.
66
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
67
+ The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
68
+ tokens.
69
+ initializer_range (`float`, *optional*, defaults to 0.02):
70
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
71
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
72
+ The epsilon used by the layer normalization layers.
73
+ use_cache (`bool`, *optional*, defaults to `True`):
74
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
75
+ relevant if `config.is_decoder=True`.
76
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
77
+ Whether to tie weight embeddings
78
+ rope_theta (`float`, *optional*, defaults to 10000.0):
79
+ The base period of the RoPE embeddings.
80
+ rope_scaling (`Dict`, *optional*):
81
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
82
+ strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
83
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
84
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
85
+ these scaling strategies behave:
86
+ https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
87
+ is an experimental feature, subject to breaking API changes in future versions.
88
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
89
+ Percentage of the query and keys which will have rotary embedding.
90
+ qk_layernorm (`bool`, *optional*, defaults to `False`):
91
+ Whether or not to normalize the Queries and Keys after projecting the hidden states.
92
+ bos_token_id (`int`, *optional*, defaults to 1):
93
+ Denotes beginning of sequences token id.
94
+ eos_token_id (`int`, *optional*, defaults to 2):
95
+ Denotes end of sequences token id.
96
+ Example:
97
+ ```python
98
+ >>> from transformers import PhiModel, PhiConfig
99
+ >>> # Initializing a Phi-1 style configuration
100
+ >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")
101
+ >>> # Initializing a model from the configuration
102
+ >>> model = PhiModel(configuration)
103
+ >>> # Accessing the model configuration
104
+ >>> configuration = model.config
105
+ ```"""
106
+
107
+ model_type = "phi"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+
110
+ def __init__(
111
+ self,
112
+ vocab_size=51200,
113
+ hidden_size=2048,
114
+ intermediate_size=8192,
115
+ num_hidden_layers=24,
116
+ num_attention_heads=32,
117
+ num_key_value_heads=None,
118
+ resid_pdrop=0.0,
119
+ embd_pdrop=0.0,
120
+ attention_dropout=0.0,
121
+ hidden_act="gelu_new",
122
+ max_position_embeddings=2048,
123
+ initializer_range=0.02,
124
+ layer_norm_eps=1e-5,
125
+ use_cache=True,
126
+ tie_word_embeddings=False,
127
+ rope_theta=10000.0,
128
+ rope_scaling=None,
129
+ partial_rotary_factor=0.5,
130
+ qk_layernorm=False,
131
+ bos_token_id=1,
132
+ eos_token_id=2,
133
+ **kwargs,
134
+ ):
135
+ self.vocab_size = vocab_size
136
+ self.hidden_size = hidden_size
137
+ self.intermediate_size = intermediate_size
138
+ self.num_hidden_layers = num_hidden_layers
139
+ self.num_attention_heads = num_attention_heads
140
+
141
+ if num_key_value_heads is None:
142
+ num_key_value_heads = num_attention_heads
143
+
144
+ self.num_key_value_heads = num_key_value_heads
145
+ self.resid_pdrop = resid_pdrop
146
+ self.embd_pdrop = embd_pdrop
147
+ self.attention_dropout = attention_dropout
148
+ self.hidden_act = hidden_act
149
+ self.max_position_embeddings = max_position_embeddings
150
+ self.initializer_range = initializer_range
151
+ self.layer_norm_eps = layer_norm_eps
152
+ self.use_cache = use_cache
153
+ self.rope_theta = rope_theta
154
+ self.rope_scaling = rope_scaling
155
+ self.partial_rotary_factor = partial_rotary_factor
156
+ self.qk_layernorm = qk_layernorm
157
+ self._rope_scaling_validation()
158
+
159
+ super().__init__(
160
+ bos_token_id=bos_token_id,
161
+ eos_token_id=eos_token_id,
162
+ tie_word_embeddings=tie_word_embeddings,
163
+ **kwargs,
164
+ )
165
+
166
+ # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
167
+ def _rope_scaling_validation(self):
168
+ """
169
+ Validate the `rope_scaling` configuration.
170
+ """
171
+ if self.rope_scaling is None:
172
+ return
173
+
174
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
175
+ raise ValueError(
176
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
177
+ f"got {self.rope_scaling}"
178
+ )
179
+ rope_scaling_type = self.rope_scaling.get("type", None)
180
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
181
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
182
+ raise ValueError(
183
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
184
+ )
185
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
186
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
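For a quick feel of the `rope_scaling` validation above, a small sketch (assuming this file is importable as `configuration_phi` and `transformers` is installed): a well-formed dict passes, while a factor that is not greater than 1 raises.

```python
# Sketch: exercising PhiConfig._rope_scaling_validation via the constructor.
from configuration_phi import PhiConfig  # assumes this file sits on the import path

cfg = PhiConfig(rope_scaling={"type": "linear", "factor": 2.0})   # accepted
try:
    PhiConfig(rope_scaling={"type": "linear", "factor": 0.5})     # factor must be a float > 1
except ValueError as err:
    print(err)
```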
configuration_qwen.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from transformers import PretrainedConfig
7
+
8
+
9
+ class QWenConfig(PretrainedConfig):
10
+ model_type = "qwen"
11
+ keys_to_ignore_at_inference = ["past_key_values"]
12
+
13
+ def __init__(
14
+ self,
15
+ vocab_size=151936,
16
+ hidden_size=4096,
17
+ num_hidden_layers=32,
18
+ num_attention_heads=32,
19
+ emb_dropout_prob=0.0,
20
+ attn_dropout_prob=0.0,
21
+ layer_norm_epsilon=1e-6,
22
+ initializer_range=0.02,
23
+ max_position_embeddings=8192,
24
+ scale_attn_weights=True,
25
+ use_cache=True,
26
+ bf16=False,
27
+ fp16=False,
28
+ fp32=False,
29
+ kv_channels=128,
30
+ rotary_pct=1.0,
31
+ rotary_emb_base=10000,
32
+ use_dynamic_ntk=True,
33
+ use_logn_attn=True,
34
+ use_flash_attn="auto",
35
+ intermediate_size=22016,
36
+ no_bias=True,
37
+ tie_word_embeddings=False,
38
+ use_cache_quantization=False,
39
+ use_cache_kernel=False,
40
+ softmax_in_fp32=False,
41
+ **kwargs,
42
+ ):
43
+ self.vocab_size = vocab_size
44
+ self.hidden_size = hidden_size
45
+ self.intermediate_size = intermediate_size
46
+ self.num_hidden_layers = num_hidden_layers
47
+ self.num_attention_heads = num_attention_heads
48
+ self.emb_dropout_prob = emb_dropout_prob
49
+ self.attn_dropout_prob = attn_dropout_prob
50
+ self.layer_norm_epsilon = layer_norm_epsilon
51
+ self.initializer_range = initializer_range
52
+ self.scale_attn_weights = scale_attn_weights
53
+ self.use_cache = use_cache
54
+ self.max_position_embeddings = max_position_embeddings
55
+ self.bf16 = bf16
56
+ self.fp16 = fp16
57
+ self.fp32 = fp32
58
+ self.kv_channels = kv_channels
59
+ self.rotary_pct = rotary_pct
60
+ self.rotary_emb_base = rotary_emb_base
61
+ self.use_dynamic_ntk = use_dynamic_ntk
62
+ self.use_logn_attn = use_logn_attn
63
+ self.use_flash_attn = use_flash_attn
64
+ self.no_bias = no_bias
65
+ self.use_cache_quantization = use_cache_quantization
66
+ self.use_cache_kernel = use_cache_kernel
67
+ self.softmax_in_fp32 = softmax_in_fp32
68
+ super().__init__(
69
+ tie_word_embeddings=tie_word_embeddings,
70
+ **kwargs
71
+ )
constants.py ADDED
@@ -0,0 +1,13 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
13
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
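`IMAGE_TOKEN_INDEX` is a sentinel that cannot collide with a real vocabulary id; it marks the position where image features are later spliced into the token sequence. A rough illustration of the idea (a toy sketch, not the repository's `tokenizer_image_token` from `tinyllava.mm_utils`):

```python
# Toy illustration: insert the image sentinel wherever DEFAULT_IMAGE_TOKEN occurs.
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

def splice_image_token(prompt, encode):
    chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
    ids = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            ids.append(IMAGE_TOKEN_INDEX)  # placeholder later replaced by image features
        ids.extend(encode(chunk))
    return ids

toy_encode = lambda text: [len(word) for word in text.split()]  # stands in for a tokenizer
print(splice_image_token("<image>\nDescribe the scene.", toy_encode))  # [-200, 8, 3, 6]
```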
controller.py ADDED
@@ -0,0 +1,298 @@
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from tinyllava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from tinyllava.utils import build_logger, server_error_msg
23
+
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
+ class DispatchMethod(Enum):
29
+ LOTTERY = auto()
30
+ SHORTEST_QUEUE = auto()
31
+
32
+ @classmethod
33
+ def from_str(cls, name):
34
+ if name == "lottery":
35
+ return cls.LOTTERY
36
+ elif name == "shortest_queue":
37
+ return cls.SHORTEST_QUEUE
38
+ else:
39
+ raise ValueError(f"Invalid dispatch method: {name}")
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class WorkerInfo:
44
+ model_names: List[str]
45
+ speed: int
46
+ queue_length: int
47
+ check_heart_beat: bool
48
+ last_heart_beat: float
49
+
50
+
51
+ def heart_beat_controller(controller):
52
+ while True:
53
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
+ controller.remove_stable_workers_by_expiration()
55
+
56
+
57
+ class Controller:
58
+ def __init__(self, dispatch_method: str):
59
+ # Dict[str -> WorkerInfo]
60
+ self.worker_info = {}
61
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
+
63
+ self.heart_beat_thread = threading.Thread(
64
+ target=heart_beat_controller, args=(self,))
65
+ self.heart_beat_thread.start()
66
+
67
+ logger.info("Init controller")
68
+
69
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
70
+ worker_status: dict):
71
+ if worker_name not in self.worker_info:
72
+ logger.info(f"Register a new worker: {worker_name}")
73
+ else:
74
+ logger.info(f"Register an existing worker: {worker_name}")
75
+
76
+ if not worker_status:
77
+ worker_status = self.get_worker_status(worker_name)
78
+ if not worker_status:
79
+ return False
80
+
81
+ self.worker_info[worker_name] = WorkerInfo(
82
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
+ check_heart_beat, time.time())
84
+
85
+ logger.info(f"Register done: {worker_name}, {worker_status}")
86
+ return True
87
+
88
+ def get_worker_status(self, worker_name: str):
89
+ try:
90
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
91
+ except requests.exceptions.RequestException as e:
92
+ logger.error(f"Get status fails: {worker_name}, {e}")
93
+ return None
94
+
95
+ if r.status_code != 200:
96
+ logger.error(f"Get status fails: {worker_name}, {r}")
97
+ return None
98
+
99
+ return r.json()
100
+
101
+ def remove_worker(self, worker_name: str):
102
+ del self.worker_info[worker_name]
103
+
104
+ def refresh_all_workers(self):
105
+ old_info = dict(self.worker_info)
106
+ self.worker_info = {}
107
+
108
+ for w_name, w_info in old_info.items():
109
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
110
+ logger.info(f"Remove stale worker: {w_name}")
111
+
112
+ def list_models(self):
113
+ model_names = set()
114
+
115
+ for w_name, w_info in self.worker_info.items():
116
+ model_names.update(w_info.model_names)
117
+
118
+ return list(model_names)
119
+
120
+ def get_worker_address(self, model_name: str):
121
+ if self.dispatch_method == DispatchMethod.LOTTERY:
122
+ worker_names = []
123
+ worker_speeds = []
124
+ for w_name, w_info in self.worker_info.items():
125
+ if model_name in w_info.model_names:
126
+ worker_names.append(w_name)
127
+ worker_speeds.append(w_info.speed)
128
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
129
+ norm = np.sum(worker_speeds)
130
+ if norm < 1e-4:
131
+ return ""
132
+ worker_speeds = worker_speeds / norm
133
+ if True: # Directly return address
134
+ pt = np.random.choice(np.arange(len(worker_names)),
135
+ p=worker_speeds)
136
+ worker_name = worker_names[pt]
137
+ return worker_name
138
+
139
+ # Check status before returning
140
+ while True:
141
+ pt = np.random.choice(np.arange(len(worker_names)),
142
+ p=worker_speeds)
143
+ worker_name = worker_names[pt]
144
+
145
+ if self.get_worker_status(worker_name):
146
+ break
147
+ else:
148
+ self.remove_worker(worker_name)
149
+ worker_speeds[pt] = 0
150
+ norm = np.sum(worker_speeds)
151
+ if norm < 1e-4:
152
+ return ""
153
+ worker_speeds = worker_speeds / norm
154
+ continue
155
+ return worker_name
156
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157
+ worker_names = []
158
+ worker_qlen = []
159
+ for w_name, w_info in self.worker_info.items():
160
+ if model_name in w_info.model_names:
161
+ worker_names.append(w_name)
162
+ worker_qlen.append(w_info.queue_length / w_info.speed)
163
+ if len(worker_names) == 0:
164
+ return ""
165
+ min_index = np.argmin(worker_qlen)
166
+ w_name = worker_names[min_index]
167
+ self.worker_info[w_name].queue_length += 1
168
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
+ return w_name
170
+ else:
171
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172
+
173
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
174
+ if worker_name not in self.worker_info:
175
+ logger.info(f"Receive unknown heart beat. {worker_name}")
176
+ return False
177
+
178
+ self.worker_info[worker_name].queue_length = queue_length
179
+ self.worker_info[worker_name].last_heart_beat = time.time()
180
+ logger.info(f"Receive heart beat. {worker_name}")
181
+ return True
182
+
183
+ def remove_stable_workers_by_expiration(self):
184
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185
+ to_delete = []
186
+ for worker_name, w_info in self.worker_info.items():
187
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188
+ to_delete.append(worker_name)
189
+
190
+ for worker_name in to_delete:
191
+ self.remove_worker(worker_name)
192
+
193
+ def worker_api_generate_stream(self, params):
194
+ worker_addr = self.get_worker_address(params["model"])
195
+ if not worker_addr:
196
+ logger.info(f"no worker: {params['model']}")
197
+ ret = {
198
+ "text": server_error_msg,
199
+ "error_code": 2,
200
+ }
201
+ yield json.dumps(ret).encode() + b"\0"
202
+
203
+ try:
204
+ response = requests.post(worker_addr + "/worker_generate_stream",
205
+ json=params, stream=True, timeout=5)
206
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207
+ if chunk:
208
+ yield chunk + b"\0"
209
+ except requests.exceptions.RequestException as e:
210
+ logger.info(f"worker timeout: {worker_addr}")
211
+ ret = {
212
+ "text": server_error_msg,
213
+ "error_code": 3,
214
+ }
215
+ yield json.dumps(ret).encode() + b"\0"
216
+
217
+
218
+ # Let the controller act as a worker to achieve hierarchical
219
+ # management. This can be used to connect isolated sub networks.
220
+ def worker_api_get_status(self):
221
+ model_names = set()
222
+ speed = 0
223
+ queue_length = 0
224
+
225
+ for w_name in self.worker_info:
226
+ worker_status = self.get_worker_status(w_name)
227
+ if worker_status is not None:
228
+ model_names.update(worker_status["model_names"])
229
+ speed += worker_status["speed"]
230
+ queue_length += worker_status["queue_length"]
231
+
232
+ return {
233
+ "model_names": list(model_names),
234
+ "speed": speed,
235
+ "queue_length": queue_length,
236
+ }
237
+
238
+
239
+ app = FastAPI()
240
+
241
+
242
+ @app.post("/register_worker")
243
+ async def register_worker(request: Request):
244
+ data = await request.json()
245
+ controller.register_worker(
246
+ data["worker_name"], data["check_heart_beat"],
247
+ data.get("worker_status", None))
248
+
249
+
250
+ @app.post("/refresh_all_workers")
251
+ async def refresh_all_workers():
252
+ models = controller.refresh_all_workers()
253
+
254
+
255
+ @app.post("/list_models")
256
+ async def list_models():
257
+ models = controller.list_models()
258
+ return {"models": models}
259
+
260
+
261
+ @app.post("/get_worker_address")
262
+ async def get_worker_address(request: Request):
263
+ data = await request.json()
264
+ addr = controller.get_worker_address(data["model"])
265
+ return {"address": addr}
266
+
267
+
268
+ @app.post("/receive_heart_beat")
269
+ async def receive_heart_beat(request: Request):
270
+ data = await request.json()
271
+ exist = controller.receive_heart_beat(
272
+ data["worker_name"], data["queue_length"])
273
+ return {"exist": exist}
274
+
275
+
276
+ @app.post("/worker_generate_stream")
277
+ async def worker_api_generate_stream(request: Request):
278
+ params = await request.json()
279
+ generator = controller.worker_api_generate_stream(params)
280
+ return StreamingResponse(generator)
281
+
282
+
283
+ @app.post("/worker_get_status")
284
+ async def worker_api_get_status(request: Request):
285
+ return controller.worker_api_get_status()
286
+
287
+
288
+ if __name__ == "__main__":
289
+ parser = argparse.ArgumentParser()
290
+ parser.add_argument("--host", type=str, default="localhost")
291
+ parser.add_argument("--port", type=int, default=21001)
292
+ parser.add_argument("--dispatch-method", type=str, choices=[
293
+ "lottery", "shortest_queue"], default="shortest_queue")
294
+ args = parser.parse_args()
295
+ logger.info(f"args: {args}")
296
+
297
+ controller = Controller(args.dispatch_method)
298
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
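A sketch of the client side of the endpoints defined above, assuming the controller is running locally on its default port; the worker URL and model name are hypothetical placeholders:

```python
# Sketch: registering a worker and resolving it through the controller's HTTP API.
import requests

controller_url = "http://localhost:21001"

# Register a (hypothetical) worker, supplying its status inline so the controller
# does not need to probe it; heartbeat tracking is disabled for this sketch.
requests.post(controller_url + "/register_worker", json={
    "worker_name": "http://localhost:40000",
    "check_heart_beat": False,
    "worker_status": {"model_names": ["tiny-llava-demo"], "speed": 1, "queue_length": 0},
})

models = requests.post(controller_url + "/list_models").json()["models"]
addr = requests.post(controller_url + "/get_worker_address",
                     json={"model": "tiny-llava-demo"}).json()["address"]
print(models, addr)
```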
conversation.py ADDED
@@ -0,0 +1,496 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+ TINY_LLAMA = auto()
14
+ QWEN_2 = auto()
15
+
16
+
17
+ @dataclasses.dataclass
18
+ class Conversation:
19
+ """A class that keeps all conversation history."""
20
+ system: str
21
+ roles: List[str]
22
+ messages: List[List[str]]
23
+ offset: int
24
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
25
+ sep: str = "###"
26
+ sep2: str = None
27
+ version: str = "Unknown"
28
+
29
+ skip_next: bool = False
30
+
31
+ def get_prompt(self):
32
+ messages = self.messages
33
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
34
+ messages = self.messages.copy()
35
+ init_role, init_msg = messages[0].copy()
36
+ init_msg = init_msg[0].replace("<image>", "").strip()
37
+ if 'mmtag' in self.version:
38
+ messages[0] = (init_role, init_msg)
39
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
40
+ messages.insert(1, (self.roles[1], "Received."))
41
+ else:
42
+ messages[0] = (init_role, "<image>\n" + init_msg)
43
+
44
+ if self.sep_style == SeparatorStyle.SINGLE:
45
+ ret = self.system + self.sep
46
+ for role, message in messages:
47
+ if message:
48
+ if type(message) is tuple:
49
+ message, _, _ = message
50
+ ret += role + ": " + message + self.sep
51
+ else:
52
+ ret += role + ":"
53
+ elif self.sep_style == SeparatorStyle.TWO:
54
+ seps = [self.sep, self.sep2]
55
+ ret = self.system + seps[0]
56
+ for i, (role, message) in enumerate(messages):
57
+ if message:
58
+ if type(message) is tuple:
59
+ message, _, _ = message
60
+ ret += role + ": " + message + seps[i % 2]
61
+ else:
62
+ ret += role + ":"
63
+ elif self.sep_style == SeparatorStyle.MPT:
64
+ ret = self.system + self.sep
65
+ for role, message in messages:
66
+ if message:
67
+ if type(message) is tuple:
68
+ message, _, _ = message
69
+ ret += role + message + self.sep
70
+ else:
71
+ ret += role
72
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
73
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
74
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
75
+ ret = ""
76
+
77
+ for i, (role, message) in enumerate(messages):
78
+ if i == 0:
79
+ assert message, "first message should not be none"
80
+ assert role == self.roles[0], "first message should come from user"
81
+ if message:
82
+ if type(message) is tuple:
83
+ message, _, _ = message
84
+ if i == 0: message = wrap_sys(self.system) + message
85
+ if i % 2 == 0:
86
+ message = wrap_inst(message)
87
+ ret += self.sep + message
88
+ else:
89
+ ret += " " + message + " " + self.sep2
90
+ else:
91
+ ret += ""
92
+ ret = ret.lstrip(self.sep)
93
+ elif self.sep_style == SeparatorStyle.TINY_LLAMA:
94
+ sep = "</s>"
95
+ wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
96
+ wrap_user = lambda msg: f"<|user|>\n{msg}\n"
97
+ wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
98
+ ret = ""
99
+
100
+ for i, (role, message) in enumerate(messages):
101
+ if i == 0:
102
+ assert message, "first message should not be none"
103
+ assert role == self.roles[0], "first message should come from user"
104
+ if message:
105
+ if type(message) is tuple:
106
+ message, _, _ = message
107
+ if i % 2 == 0:
108
+ message = wrap_user(message)
109
+ if i == 0:
110
+ message = wrap_sys(self.system) + message
111
+ ret += self.sep + message
112
+ else:
113
+ message = wrap_assistant(message) + self.sep2
114
+ ret += message
115
+ else:
116
+ ret += "<|assistant|>\n"
117
+ ret = ret.lstrip(self.sep)
118
+ elif self.sep_style == SeparatorStyle.QWEN_2:
119
+ ret = self.system + self.sep
120
+ for role, message in messages:
121
+ if message:
122
+ if type(message) is tuple:
123
+ message, _, _ = message
124
+ ret += role + message + self.sep
125
+ else:
126
+ ret += role
127
+ elif self.sep_style == SeparatorStyle.PLAIN:
128
+ seps = [self.sep, self.sep2]
129
+ ret = self.system
130
+ for i, (role, message) in enumerate(messages):
131
+ if message:
132
+ if type(message) is tuple:
133
+ message, _, _ = message
134
+ ret += message + seps[i % 2]
135
+ else:
136
+ ret += ""
137
+ else:
138
+ raise ValueError(f"Invalid style: {self.sep_style}")
139
+
140
+ return ret
141
+
142
+ def append_message(self, role, message):
143
+ self.messages.append([role, message])
144
+
145
+ def get_images(self, return_pil=False):
146
+ images = []
147
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
148
+ if i % 2 == 0:
149
+ if type(msg) is tuple:
150
+ import base64
151
+ from io import BytesIO
152
+ from PIL import Image
153
+ msg, image, image_process_mode = msg
154
+ if image_process_mode == "Pad":
155
+ def expand2square(pil_img, background_color=(122, 116, 104)):
156
+ width, height = pil_img.size
157
+ if width == height:
158
+ return pil_img
159
+ elif width > height:
160
+ result = Image.new(pil_img.mode, (width, width), background_color)
161
+ result.paste(pil_img, (0, (width - height) // 2))
162
+ return result
163
+ else:
164
+ result = Image.new(pil_img.mode, (height, height), background_color)
165
+ result.paste(pil_img, ((height - width) // 2, 0))
166
+ return result
167
+ image = expand2square(image)
168
+ elif image_process_mode in ["Default", "Crop"]:
169
+ pass
170
+ elif image_process_mode == "Resize":
171
+ image = image.resize((336, 336))
172
+ else:
173
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
174
+ max_hw, min_hw = max(image.size), min(image.size)
175
+ aspect_ratio = max_hw / min_hw
176
+ max_len, min_len = 800, 400
177
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
178
+ longest_edge = int(shortest_edge * aspect_ratio)
179
+ W, H = image.size
180
+ if longest_edge != max(image.size):
181
+ if H > W:
182
+ H, W = longest_edge, shortest_edge
183
+ else:
184
+ H, W = shortest_edge, longest_edge
185
+ image = image.resize((W, H))
186
+ if return_pil:
187
+ images.append(image)
188
+ else:
189
+ buffered = BytesIO()
190
+ image.save(buffered, format="PNG")
191
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
192
+ images.append(img_b64_str)
193
+ return images
194
+
195
+ def to_gradio_chatbot(self):
196
+ ret = []
197
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
198
+ if i % 2 == 0:
199
+ if type(msg) is tuple:
200
+ import base64
201
+ from io import BytesIO
202
+ msg, image, image_process_mode = msg
203
+ max_hw, min_hw = max(image.size), min(image.size)
204
+ aspect_ratio = max_hw / min_hw
205
+ max_len, min_len = 800, 400
206
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
207
+ longest_edge = int(shortest_edge * aspect_ratio)
208
+ W, H = image.size
209
+ if H > W:
210
+ H, W = longest_edge, shortest_edge
211
+ else:
212
+ H, W = shortest_edge, longest_edge
213
+ image = image.resize((W, H))
214
+ buffered = BytesIO()
215
+ image.save(buffered, format="JPEG")
216
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
217
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
218
+ msg = img_str + msg.replace('<image>', '').strip()
219
+ ret.append([msg, None])
220
+ else:
221
+ ret.append([msg, None])
222
+ else:
223
+ ret[-1][-1] = msg
224
+ return ret
225
+
226
+ def copy(self):
227
+ return Conversation(
228
+ system=self.system,
229
+ roles=self.roles,
230
+ messages=[[x, y] for x, y in self.messages],
231
+ offset=self.offset,
232
+ sep_style=self.sep_style,
233
+ sep=self.sep,
234
+ sep2=self.sep2,
235
+ version=self.version)
236
+
237
+ def dict(self):
238
+ if len(self.get_images()) > 0:
239
+ return {
240
+ "system": self.system,
241
+ "roles": self.roles,
242
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
243
+ "offset": self.offset,
244
+ "sep": self.sep,
245
+ "sep2": self.sep2,
246
+ }
247
+ return {
248
+ "system": self.system,
249
+ "roles": self.roles,
250
+ "messages": self.messages,
251
+ "offset": self.offset,
252
+ "sep": self.sep,
253
+ "sep2": self.sep2,
254
+ }
255
+
256
+
257
+ conv_vicuna_v0 = Conversation(
258
+ system="A chat between a curious human and an artificial intelligence assistant. "
259
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
260
+ roles=("Human", "Assistant"),
261
+ messages=(
262
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
263
+ ("Assistant",
264
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
265
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
266
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
267
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
268
+ "renewable and non-renewable energy sources:\n"
269
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
270
+ "energy sources are finite and will eventually run out.\n"
271
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
272
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
273
+ "and other negative effects.\n"
274
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
275
+ "have lower operational costs than non-renewable sources.\n"
276
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
277
+ "locations than non-renewable sources.\n"
278
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
279
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
280
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
281
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
282
+ ),
283
+ offset=2,
284
+ sep_style=SeparatorStyle.SINGLE,
285
+ sep="###",
286
+ )
287
+
288
+ conv_vicuna_v1 = Conversation(
289
+ system="A chat between a curious user and an artificial intelligence assistant. "
290
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
291
+ roles=("USER", "ASSISTANT"),
292
+ version="v1",
293
+ messages=(),
294
+ offset=0,
295
+ sep_style=SeparatorStyle.TWO,
296
+ sep=" ",
297
+ sep2="</s>",
298
+ )
299
+
300
+ conv_llama_2 = Conversation(
301
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
302
+
303
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
304
+ roles=("USER", "ASSISTANT"),
305
+ version="llama_v2",
306
+ messages=(),
307
+ offset=0,
308
+ sep_style=SeparatorStyle.LLAMA_2,
309
+ sep="<s>",
310
+ sep2="</s>",
311
+ )
312
+
313
+ conv_llava_llama_2 = Conversation(
314
+ system="You are a helpful language and vision assistant. "
315
+ "You are able to understand the visual content that the user provides, "
316
+ "and assist the user with a variety of tasks using natural language.",
317
+ roles=("USER", "ASSISTANT"),
318
+ version="llama_v2",
319
+ messages=(),
320
+ offset=0,
321
+ sep_style=SeparatorStyle.LLAMA_2,
322
+ sep="<s>",
323
+ sep2="</s>",
324
+ )
325
+
326
+ conv_tiny_llava_tiny_llama = Conversation(
327
+ system="You are a helpful language and vision assistant. "
328
+ "You are able to understand the visual content that the user provides, "
329
+ "and assist the user with a variety of tasks using natural language.",
330
+ roles=("USER", "ASSISTANT"),
331
+ version="tiny_llama",
332
+ messages=(),
333
+ offset=0,
334
+ sep_style=SeparatorStyle.TINY_LLAMA,
335
+ sep="<s>",
336
+ sep2="</s>"
337
+ )
338
+
339
+
340
+ conv_mpt = Conversation(
341
+ system="""<|im_start|>system
342
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
343
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
344
+ version="mpt",
345
+ messages=(),
346
+ offset=0,
347
+ sep_style=SeparatorStyle.MPT,
348
+ sep="<|im_end|>",
349
+ )
350
+
351
+ conv_llava_plain = Conversation(
352
+ system="",
353
+ roles=("", ""),
354
+ messages=(
355
+ ),
356
+ version='plain',
357
+ offset=0,
358
+ sep_style=SeparatorStyle.PLAIN,
359
+ sep="\n",
360
+ )
361
+
362
+ conv_llava_v0 = Conversation(
363
+ system="A chat between a curious human and an artificial intelligence assistant. "
364
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
365
+ roles=("Human", "Assistant"),
366
+ messages=(
367
+ ),
368
+ offset=0,
369
+ sep_style=SeparatorStyle.SINGLE,
370
+ sep="###",
371
+ )
372
+
373
+ conv_llava_v0_mmtag = Conversation(
374
+ system="A chat between a curious user and an artificial intelligence assistant. "
375
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
376
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
377
+ roles=("Human", "Assistant"),
378
+ messages=(
379
+ ),
380
+ offset=0,
381
+ sep_style=SeparatorStyle.SINGLE,
382
+ sep="###",
383
+ version="v0_mmtag",
384
+ )
385
+
386
+ conv_llava_v1 = Conversation(
387
+ system="A chat between a curious human and an artificial intelligence assistant. "
388
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
389
+ roles=("USER", "ASSISTANT"),
390
+ version="v1",
391
+ messages=(),
392
+ offset=0,
393
+ sep_style=SeparatorStyle.TWO,
394
+ sep=" ",
395
+ sep2="</s>",
396
+ )
397
+
398
+ conv_llava_v1_mmtag = Conversation(
399
+ system="A chat between a curious user and an artificial intelligence assistant. "
400
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
401
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
402
+ roles=("USER", "ASSISTANT"),
403
+ messages=(),
404
+ offset=0,
405
+ sep_style=SeparatorStyle.TWO,
406
+ sep=" ",
407
+ sep2="</s>",
408
+ version="v1_mmtag",
409
+ )
410
+
411
+ conv_phi_v0 = Conversation(
412
+ system="A chat between a curious user and an artificial intelligence assistant. "
413
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
414
+ roles=("USER", "ASSISTANT"),
415
+ version="phi",
416
+ messages=(),
417
+ offset=0,
418
+ sep_style=SeparatorStyle.TWO,
419
+ sep=" ",
420
+ sep2="<|endoftext|>",
421
+ )
422
+
423
+ conv_stablelm = Conversation(
424
+ system="A chat between a curious user and an artificial intelligence assistant. "
425
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
426
+ roles=("USER", "ASSISTANT"),
427
+ version="stablelm",
428
+ messages=(),
429
+ offset=0,
430
+ sep_style=SeparatorStyle.TWO,
431
+ sep=" ",
432
+ sep2="<|endoftext|>",
433
+ )
434
+
435
+ conv_mistral_instruct = Conversation(
436
+ system="",
437
+ roles=("USER", "ASSISTANT"),
438
+ version="llama_v2",
439
+ messages=(),
440
+ offset=0,
441
+ sep_style=SeparatorStyle.LLAMA_2,
442
+ sep="",
443
+ sep2="</s>",
444
+ )
445
+
446
+ conv_chatml_direct = Conversation(
447
+ system="""<|im_start|>system
448
+ Answer the questions.""",
449
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
450
+ version="mpt",
451
+ messages=(),
452
+ offset=0,
453
+ sep_style=SeparatorStyle.MPT,
454
+ sep="<|im_end|>",
455
+ )
456
+
457
+ conv_qwen2 = Conversation(
458
+ system="<|im_start|>system\nYou are a helpful assistant",
459
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
460
+ version="mpt",
461
+ messages=(),
462
+ offset=0,
463
+ sep_style=SeparatorStyle.MPT,
464
+ sep="<|im_end|>"
465
+ )
466
+
467
+ default_conversation = conv_vicuna_v1
468
+ conv_templates = {
469
+ "default": conv_vicuna_v0,
470
+ "v0": conv_vicuna_v0,
471
+ "v1": conv_vicuna_v1,
472
+ "vicuna_v1": conv_vicuna_v1,
473
+ "llama_2": conv_llama_2,
474
+
475
+ "plain": conv_llava_plain,
476
+ "v0_plain": conv_llava_plain,
477
+ "llava_v0": conv_llava_v0,
478
+ "v0_mmtag": conv_llava_v0_mmtag,
479
+ "llava_v1": conv_llava_v1,
480
+ "v1_mmtag": conv_llava_v1_mmtag,
481
+ "llava_llama_2": conv_llava_llama_2,
482
+
483
+ "mpt": conv_mpt,
484
+
485
+ "tiny_llama": conv_tiny_llava_tiny_llama,
486
+ "phi": conv_phi_v0,
+ "qwen": conv_qwen2,
487
+
488
+ # added by llava-1.6
489
+ "mistral_instruct": conv_mistral_instruct,
490
+ "chatml_direct": conv_chatml_direct,
491
+ "mistral_direct": conv_chatml_direct,
492
+ }
493
+
494
+
495
+ if __name__ == "__main__":
496
+ print(default_conversation.get_prompt())
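A short sketch of how these templates are consumed (mirroring the flow in cli.py): copy a template, append one user turn and an empty assistant turn, then render the prompt string that gets tokenized.

```python
# Usage sketch for the templates above; the question text is an arbitrary example.
from tinyllava.conversation import conv_templates

conv = conv_templates["phi"].copy()
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this picture?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# -> "<system prompt> USER: <image>\nWhat is shown in this picture? ASSISTANT:"
```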
convert_phi_weights_to_hf.py ADDED
@@ -0,0 +1,175 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Weights conversion script for Phi
18
+
19
+ This script downloads both Phi-1 and Phi-1.5 checkpoints to "checkpoint_path" and then converts the weights to
20
+ HuggingFace model's format and saves them in "pytorch_dump_folder_path".
21
+ """
22
+
23
+ import argparse
24
+ import gc
25
+ import os
26
+
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ from modeling_phi import PhiConfig, PhiForCausalLM
31
+
32
+
33
+ _MODELS = {
34
+ "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin",
35
+ "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin",
36
+ }
37
+
38
+
39
+ PHI_MAPPING = {
40
+ "layers.0.wte.weight": "model.embed_tokens.weight",
41
+ "layers.25.linear.bias": "lm_head.bias",
42
+ "layers.25.linear.weight": "lm_head.weight",
43
+ "layers.25.ln.bias": "model.final_layernorm.bias",
44
+ "layers.25.ln.weight": "model.final_layernorm.weight",
45
+ "layers": "model.layers",
46
+ "ln": "input_layernorm",
47
+ "mixer": "self_attn",
48
+ "Wqkv": "query_key_value",
49
+ "out_proj": "dense",
50
+ }
51
+
52
+
53
+ def convert_weights(original_weights, mapping, config):
54
+ converted_weights = {}
55
+ original_weights_keys = sorted(original_weights.keys())
56
+
57
+ # we change names (1-24) -> layers(0-23) for Phi model layers
58
+ range_change = {
59
+ f"layers.{k}.": f"layers.{v}."
60
+ for k, v in zip(range(1, config.num_hidden_layers + 1), range(0, config.num_hidden_layers))
61
+ }
62
+
63
+ mapping.update(**range_change)
64
+
65
+ for original_weights_key in original_weights_keys:
66
+ new_key = original_weights_key
67
+
68
+ if "rotary_emb" in new_key:
69
+ continue
70
+
71
+ if "Wqkv" in new_key:
72
+ if "weight" in new_key:
73
+ weight = original_weights[new_key]
74
+ weights_shape = weight.shape
75
+ weight = (
76
+ weight.view(3, config.num_attention_heads, -1, config.hidden_size)
77
+ .transpose(0, 1)
78
+ .reshape(*weights_shape)
79
+ )
80
+ original_weights[new_key] = weight
81
+ elif "bias" in new_key:
82
+ bias = original_weights[new_key]
83
+ bias_shape = bias.shape
84
+ bias = bias.view(3, config.num_attention_heads, -1).transpose(0, 1).reshape(*bias_shape)
85
+ original_weights[new_key] = bias
86
+
87
+ for k, v in mapping.items():
88
+ if k in new_key:
89
+ new_key = new_key.replace(k, v)
90
+
91
+ converted_weights[new_key] = original_weights.pop(original_weights_key)
92
+
93
+ return converted_weights
94
+
95
+
96
+ def _download(url: str, root: str):
97
+ repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
98
+ filename = f"{url.split('/')[-1]}"
99
+ hf_hub_download(
100
+ repo_id=repo_id,
101
+ filename=filename,
102
+ force_filename=root,
103
+ local_dir_use_symlinks=False,
104
+ )
105
+
106
+
107
+ def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly):
108
+ device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
109
+ for each_model_name, each_model_url in _MODELS.items():
110
+ converted_checkpoint = {}
111
+
112
+ model_path = os.path.join(checkpoint_path, each_model_name + "_" + each_model_url.split("/")[-1])
113
+ if not os.path.exists(model_path):
114
+ print(f"\n{each_model_name} was not found! Downloading it to {model_path}")
115
+ _download(url=each_model_url, root=model_path)
116
+ model_checkpoint = torch.load(model_path, map_location=device)
117
+ model_type = each_model_name.split("/")[1] # phi-1 or phi-1_5
118
+ config = PhiConfig.from_pretrained(f"susnato/{model_type}_dev")
119
+
120
+ # Converting the weights
121
+ converted_checkpoint.update(**convert_weights(model_checkpoint, PHI_MAPPING, config))
122
+
123
+ # Save either the whole model or the converted weights
124
+ if save_weights_directly:
125
+ save_weights_path = os.path.join(
126
+ pytorch_dump_folder_path, each_model_name.split("/")[-1] + "_" + each_model_url.split("/")[-1]
127
+ )
128
+ torch.save(converted_checkpoint, save_weights_path)
129
+ print(f"Model weights saved at {save_weights_path}!")
130
+
131
+ else:
132
+ model = PhiForCausalLM(config).to(device)
133
+ model.load_state_dict(converted_checkpoint, strict=True)
134
+ save_model_path = os.path.join(pytorch_dump_folder_path, model_type)
135
+ model.save_pretrained(save_model_path)
136
+ print(f"Model saved at {save_model_path}!")
137
+
138
+ # release GPU memory for the 2nd model if cuda was used.
139
+ del config, model
140
+
141
+ # release GPU memory for the 2nd model if cuda was used.
142
+ del model_checkpoint, converted_checkpoint
143
+ if use_cuda:
144
+ torch.cuda.empty_cache()
145
+ gc.collect()
146
+
147
+
148
+ if __name__ == "__main__":
149
+ parser = argparse.ArgumentParser()
150
+ # Required parameters
151
+ parser.add_argument(
152
+ "--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
153
+ )
154
+ parser.add_argument(
155
+ "--pytorch_dump_folder_path",
156
+ default=None,
157
+ type=str,
158
+ help="Path to the output PyTorch model. (Please enter full path)",
159
+ )
160
+ parser.add_argument(
161
+ "--use_cuda",
162
+ default=False,
163
+ type=bool,
164
+ help="Whether to load the weights on GPU during conversion or not, False by default",
165
+ )
166
+ parser.add_argument(
167
+ "--save_weights_directly",
168
+ default=True,
169
+ type=bool,
170
+ help="Whether to save the weights directly after conversion or load the weight to the Phi model and then save "
171
+ "the Phi model along with weights. True by default",
172
+ )
173
+
174
+ args = parser.parse_args()
175
+ convert_phi_weights(args.checkpoint_path, args.pytorch_dump_folder_path, args.use_cuda, args.save_weights_directly)
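A hedged sketch of driving the converter programmatically rather than through argparse; the two directory paths are placeholders, and it assumes the script (plus the modeling_phi module it imports) is importable from the working directory:

    # Sketch only: call the conversion entry point directly.
    from convert_phi_weights_to_hf import convert_phi_weights

    convert_phi_weights(
        checkpoint_path="./phi_checkpoints",        # placeholder download directory
        pytorch_dump_folder_path="./phi_hf_dump",   # placeholder output directory
        use_cuda=False,                             # convert on CPU
        save_weights_directly=True,                 # save the converted state dict as-is
    )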
cpp_kernels.py ADDED
@@ -0,0 +1,55 @@
1
+ from torch.utils import cpp_extension
2
+ import pathlib
3
+ import os
4
+ import subprocess
5
+
6
+ def _get_cuda_bare_metal_version(cuda_dir):
7
+ raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
8
+ universal_newlines=True)
9
+ output = raw_output.split()
10
+ release_idx = output.index("release") + 1
11
+ release = output[release_idx].split(".")
12
+ bare_metal_major = release[0]
13
+ bare_metal_minor = release[1][0]
14
+
15
+ return raw_output, bare_metal_major, bare_metal_minor
16
+
17
+ def _create_build_dir(buildpath):
18
+ try:
19
+ os.mkdir(buildpath)
20
+ except OSError:
21
+ if not os.path.isdir(buildpath):
22
+ print(f"Creation of the build directory {buildpath} failed")
23
+
24
+ # Check if cuda 11 is installed for compute capability 8.0
25
+ cc_flag = []
26
+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
27
+ if int(bare_metal_major) >= 11:
28
+ cc_flag.append('-gencode')
29
+ cc_flag.append('arch=compute_80,code=sm_80')
30
+ if int(bare_metal_minor) >= 7:
31
+ cc_flag.append('-gencode')
32
+ cc_flag.append('arch=compute_90,code=sm_90')
33
+
34
+ # Build path
35
+ srcpath = pathlib.Path(__file__).parent.absolute()
36
+ buildpath = srcpath / 'build'
37
+ _create_build_dir(buildpath)
38
+
39
+ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
40
+ return cpp_extension.load(
41
+ name=name,
42
+ sources=sources,
43
+ build_directory=buildpath,
44
+ extra_cflags=['-O3', ],
45
+ extra_cuda_cflags=['-O3',
46
+ '-gencode', 'arch=compute_70,code=sm_70',
47
+ '--use_fast_math'] + extra_cuda_flags + cc_flag,
48
+ verbose=1
49
+ )
50
+
51
+ extra_flags = []
52
+
53
+ cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp",
54
+ "./cache_autogptq_cuda_kernel_256.cu"]
55
+ cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags)
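Everything in this file executes at import time: it locates nvcc through cpp_extension.CUDA_HOME, creates a build/ directory next to the file, and JIT-compiles the extension from the two source files referenced above. A hedged sketch, assuming CUDA and those .cpp/.cu sources are present:

    # Importing triggers the JIT build; the compiled extension is exposed as a module attribute.
    import cpp_kernels

    kernels = cpp_kernels.cache_autogptq_cuda_256   # torch C++/CUDA extension handle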
eval_gpt_review.py ADDED
@@ -0,0 +1,113 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import openai
6
+ import tqdm
7
+ import ray
8
+ import time
9
+
10
+ NUM_SECONDS_TO_SLEEP = 3
11
+
12
+ @ray.remote(num_cpus=4)
13
+ def get_eval(content: str, max_tokens: int):
14
+ while True:
15
+ try:
16
+ response = openai.ChatCompletion.create(
17
+ model='gpt-4',
18
+ messages=[{
19
+ 'role': 'system',
20
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21
+ }, {
22
+ 'role': 'user',
23
+ 'content': content,
24
+ }],
25
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
26
+ max_tokens=max_tokens,
27
+ )
28
+ break
29
+ except openai.error.RateLimitError:
30
+ pass
31
+ except Exception as e:
32
+ print(e)
33
+ time.sleep(NUM_SECONDS_TO_SLEEP)
34
+
35
+ print('success!')
36
+ return response['choices'][0]['message']['content']
37
+
38
+
39
+ def parse_score(review):
40
+ try:
41
+ score_pair = review.split('\n')[0]
42
+ score_pair = score_pair.replace(',', ' ')
43
+ sp = score_pair.split(' ')
44
+ if len(sp) == 2:
45
+ return [float(sp[0]), float(sp[1])]
46
+ else:
47
+ print('error', review)
48
+ return [-1, -1]
49
+ except Exception as e:
50
+ print(e)
51
+ print('error', review)
52
+ return [-1, -1]
53
+
54
+
55
+ if __name__ == '__main__':
56
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
57
+ parser.add_argument('-q', '--question')
58
+ # parser.add_argument('-a', '--answer')
59
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
60
+ parser.add_argument('-r', '--rule')
61
+ parser.add_argument('-o', '--output')
62
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
63
+ args = parser.parse_args()
64
+
65
+ ray.init()
66
+
67
+ f_q = open(os.path.expanduser(args.question))
68
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
69
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
70
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71
+
72
+ review_file = open(f'{args.output}', 'w')
73
+
74
+ js_list = []
75
+ handles = []
76
+ idx = 0
77
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
78
+ # if idx == 1:
79
+ # break
80
+
81
+ ques = json.loads(ques_js)
82
+ ans1 = json.loads(ans1_js)
83
+ ans2 = json.loads(ans2_js)
84
+
85
+ category = json.loads(ques_js)['category']
86
+ if category in rule_dict:
87
+ rule = rule_dict[category]
88
+ else:
89
+ rule = rule_dict['default']
90
+ prompt = rule['prompt']
91
+ role = rule['role']
92
+ content = (f'[Question]\n{ques["text"]}\n\n'
93
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
94
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
95
+ f'[System]\n{prompt}\n\n')
96
+ js_list.append({
97
+ 'id': idx+1,
98
+ 'question_id': ques['question_id'],
99
+ 'answer1_id': ans1['answer_id'],
100
+ 'answer2_id': ans2['answer_id'],
101
+ 'category': category})
102
+ idx += 1
103
+ handles.append(get_eval.remote(content, args.max_tokens))
104
+ # To avoid the rate limit set by OpenAI
105
+ time.sleep(NUM_SECONDS_TO_SLEEP)
106
+
107
+ reviews = ray.get(handles)
108
+ for idx, review in enumerate(reviews):
109
+ scores = parse_score(review)
110
+ js_list[idx]['content'] = review
111
+ js_list[idx]['tuple'] = scores
112
+ review_file.write(json.dumps(js_list[idx]) + '\n')
113
+ review_file.close()
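The GPT-4 reviewer is expected to start its reply with a pair of scores on the first line, which is all parse_score inspects. A small sketch of that behaviour (review strings are made up):

    # parse_score reads the score pair from the first line of the review text.
    parse_score("8 9\nAssistant 1 was concise ...")   # -> [8.0, 9.0]
    parse_score("Both answers were good.")            # -> [-1, -1] (and prints an error)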
eval_gpt_review_bench.py ADDED
@@ -0,0 +1,121 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import openai
6
+ import time
7
+
8
+ NUM_SECONDS_TO_SLEEP = 0.5
9
+
10
+
11
+ def get_eval(content: str, max_tokens: int):
12
+ while True:
13
+ try:
14
+ response = openai.ChatCompletion.create(
15
+ model='gpt-4',
16
+ messages=[{
17
+ 'role': 'system',
18
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19
+ }, {
20
+ 'role': 'user',
21
+ 'content': content,
22
+ }],
23
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
24
+ max_tokens=max_tokens,
25
+ )
26
+ break
27
+ except openai.error.RateLimitError:
28
+ pass
29
+ except Exception as e:
30
+ print(e)
31
+ time.sleep(NUM_SECONDS_TO_SLEEP)
32
+
33
+ return response['choices'][0]['message']['content']
34
+
35
+
36
+ def parse_score(review):
37
+ try:
38
+ score_pair = review.split('\n')[0]
39
+ score_pair = score_pair.replace(',', ' ')
40
+ sp = score_pair.split(' ')
41
+ if len(sp) == 2:
42
+ return [float(sp[0]), float(sp[1])]
43
+ else:
44
+ print('error', review)
45
+ return [-1, -1]
46
+ except Exception as e:
47
+ print(e)
48
+ print('error', review)
49
+ return [-1, -1]
50
+
51
+
52
+ if __name__ == '__main__':
53
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54
+ parser.add_argument('-q', '--question')
55
+ parser.add_argument('-c', '--context')
56
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57
+ parser.add_argument('-r', '--rule')
58
+ parser.add_argument('-o', '--output')
59
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60
+ args = parser.parse_args()
61
+
62
+ f_q = open(os.path.expanduser(args.question))
63
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66
+
67
+ if os.path.isfile(os.path.expanduser(args.output)):
68
+ cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69
+ else:
70
+ cur_reviews = []
71
+
72
+ review_file = open(f'{args.output}', 'a')
73
+
74
+ context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75
+ image_to_context = {context['image']: context for context in context_list}
76
+
77
+ handles = []
78
+ idx = 0
79
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80
+ ques = json.loads(ques_js)
81
+ ans1 = json.loads(ans1_js)
82
+ ans2 = json.loads(ans2_js)
83
+
84
+ inst = image_to_context[ques['image']]
85
+
86
+ if isinstance(inst['caption'], list):
87
+ cap_str = '\n'.join(inst['caption'])
88
+ else:
89
+ cap_str = inst['caption']
90
+
91
+ category = 'llava_bench_' + json.loads(ques_js)['category']
92
+ if category in rule_dict:
93
+ rule = rule_dict[category]
94
+ else:
95
+ assert False, f"Visual QA category not found in rule file: {category}."
96
+ prompt = rule['prompt']
97
+ role = rule['role']
98
+ content = (f'[Context]\n{cap_str}\n\n'
99
+ f'[Question]\n{ques["text"]}\n\n'
100
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
101
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
102
+ f'[System]\n{prompt}\n\n')
103
+ cur_js = {
104
+ 'id': idx+1,
105
+ 'question_id': ques['question_id'],
106
+ 'answer1_id': ans1.get('answer_id', ans1['question_id']),
107
+ 'answer2_id': ans2.get('answer_id', ans2['question_id']),
108
+ 'category': category
109
+ }
110
+ if idx >= len(cur_reviews):
111
+ review = get_eval(content, args.max_tokens)
112
+ scores = parse_score(review)
113
+ cur_js['content'] = review
114
+ cur_js['tuple'] = scores
115
+ review_file.write(json.dumps(cur_js) + '\n')
116
+ review_file.flush()
117
+ else:
118
+ print(f'Skipping {idx} as we already have it.')
119
+ idx += 1
120
+ print(idx)
121
+ review_file.close()
eval_gpt_review_visual.py ADDED
@@ -0,0 +1,118 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import openai
6
+ import time
7
+
8
+ NUM_SECONDS_TO_SLEEP = 0.5
9
+
10
+
11
+ def get_eval(content: str, max_tokens: int):
12
+ while True:
13
+ try:
14
+ response = openai.ChatCompletion.create(
15
+ model='gpt-4-0314',
16
+ messages=[{
17
+ 'role': 'system',
18
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19
+ }, {
20
+ 'role': 'user',
21
+ 'content': content,
22
+ }],
23
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
24
+ max_tokens=max_tokens,
25
+ )
26
+ break
27
+ except openai.error.RateLimitError:
28
+ pass
29
+ except Exception as e:
30
+ print(e)
31
+ time.sleep(NUM_SECONDS_TO_SLEEP)
32
+
33
+ return response['choices'][0]['message']['content']
34
+
35
+
36
+ def parse_score(review):
37
+ try:
38
+ score_pair = review.split('\n')[0]
39
+ score_pair = score_pair.replace(',', ' ')
40
+ sp = score_pair.split(' ')
41
+ if len(sp) == 2:
42
+ return [float(sp[0]), float(sp[1])]
43
+ else:
44
+ print('error', review)
45
+ return [-1, -1]
46
+ except Exception as e:
47
+ print(e)
48
+ print('error', review)
49
+ return [-1, -1]
50
+
51
+
52
+ if __name__ == '__main__':
53
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54
+ parser.add_argument('-q', '--question')
55
+ parser.add_argument('-c', '--context')
56
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57
+ parser.add_argument('-r', '--rule')
58
+ parser.add_argument('-o', '--output')
59
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60
+ args = parser.parse_args()
61
+
62
+ f_q = open(os.path.expanduser(args.question))
63
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66
+
67
+ if os.path.isfile(os.path.expanduser(args.output)):
68
+ cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69
+ else:
70
+ cur_reviews = []
71
+
72
+ review_file = open(f'{args.output}', 'a')
73
+
74
+ context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75
+ image_to_context = {context['image']: context for context in context_list}
76
+
77
+ handles = []
78
+ idx = 0
79
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80
+ ques = json.loads(ques_js)
81
+ ans1 = json.loads(ans1_js)
82
+ ans2 = json.loads(ans2_js)
83
+
84
+ inst = image_to_context[ques['image']]
85
+ cap_str = '\n'.join(inst['captions'])
86
+ box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
87
+
88
+ category = json.loads(ques_js)['category']
89
+ if category in rule_dict:
90
+ rule = rule_dict[category]
91
+ else:
92
+ assert False, f"Visual QA category not found in rule file: {category}."
93
+ prompt = rule['prompt']
94
+ role = rule['role']
95
+ content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
96
+ f'[Question]\n{ques["text"]}\n\n'
97
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
98
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
99
+ f'[System]\n{prompt}\n\n')
100
+ cur_js = {
101
+ 'id': idx+1,
102
+ 'question_id': ques['question_id'],
103
+ 'answer1_id': ans1.get('answer_id', ans1['question_id']),
104
+ 'answer2_id': ans2.get('answer_id', ans2['question_id']),
105
+ 'category': category
106
+ }
107
+ if idx >= len(cur_reviews):
108
+ review = get_eval(content, args.max_tokens)
109
+ scores = parse_score(review)
110
+ cur_js['content'] = review
111
+ cur_js['tuple'] = scores
112
+ review_file.write(json.dumps(cur_js) + '\n')
113
+ review_file.flush()
114
+ else:
115
+ print(f'Skipping {idx} as we already have it.')
116
+ idx += 1
117
+ print(idx)
118
+ review_file.close()
eval_pope.py ADDED
@@ -0,0 +1,81 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+
5
+ def eval_pope(answers, label_file):
6
+ label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7
+
8
+ for answer in answers:
9
+ text = answer['text']
10
+
11
+ # Only keep the first sentence
12
+ if text.find('.') != -1:
13
+ text = text.split('.')[0]
14
+
15
+ text = text.replace(',', '')
16
+ words = text.split(' ')
17
+ if 'No' in words or 'not' in words or 'no' in words:
18
+ answer['text'] = 'no'
19
+ else:
20
+ answer['text'] = 'yes'
21
+
22
+ for i in range(len(label_list)):
23
+ if label_list[i] == 'no':
24
+ label_list[i] = 0
25
+ else:
26
+ label_list[i] = 1
27
+
28
+ pred_list = []
29
+ for answer in answers:
30
+ if answer['text'] == 'no':
31
+ pred_list.append(0)
32
+ else:
33
+ pred_list.append(1)
34
+
35
+ pos = 1
36
+ neg = 0
37
+ yes_ratio = pred_list.count(1) / len(pred_list)
38
+
39
+ TP, TN, FP, FN = 0, 0, 0, 0
40
+ for pred, label in zip(pred_list, label_list):
41
+ if pred == pos and label == pos:
42
+ TP += 1
43
+ elif pred == pos and label == neg:
44
+ FP += 1
45
+ elif pred == neg and label == neg:
46
+ TN += 1
47
+ elif pred == neg and label == pos:
48
+ FN += 1
49
+
50
+ print('TP\tFP\tTN\tFN\t')
51
+ print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52
+
53
+ precision = float(TP) / float(TP + FP)
54
+ recall = float(TP) / float(TP + FN)
55
+ f1 = 2*precision*recall / (precision + recall)
56
+ acc = (TP + TN) / (TP + TN + FP + FN)
57
+ print('Accuracy: {}'.format(acc))
58
+ print('Precision: {}'.format(precision))
59
+ print('Recall: {}'.format(recall))
60
+ print('F1 score: {}'.format(f1))
61
+ print('Yes ratio: {}'.format(yes_ratio))
62
+ print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63
+
64
+ if __name__ == "__main__":
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--annotation-dir", type=str)
67
+ parser.add_argument("--question-file", type=str)
68
+ parser.add_argument("--result-file", type=str)
69
+ args = parser.parse_args()
70
+
71
+ questions = [json.loads(line) for line in open(args.question_file)]
72
+ questions = {question['question_id']: question for question in questions}
73
+ answers = [json.loads(q) for q in open(args.result_file)]
74
+ for file in os.listdir(args.annotation_dir):
75
+ assert file.startswith('coco_pope_')
76
+ assert file.endswith('.json')
77
+ category = file[10:-5]
78
+ cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79
+ print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80
+ eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81
+ print("====================================")
eval_science_qa.py ADDED
@@ -0,0 +1,114 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import random
6
+
7
+
8
+ def get_args():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--base-dir', type=str)
11
+ parser.add_argument('--result-file', type=str)
12
+ parser.add_argument('--output-file', type=str)
13
+ parser.add_argument('--output-result', type=str)
14
+ parser.add_argument('--split', type=str, default='test')
15
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
+ return parser.parse_args()
17
+
18
+
19
+ def convert_caps(results):
20
+ fakecaps = []
21
+ for result in results:
22
+ image_id = result['question_id']
23
+ caption = result['text']
24
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
25
+ return fakecaps
26
+
27
+
28
+ def get_pred_idx(prediction, choices, options):
29
+ """
30
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
31
+ """
32
+ if prediction in options[:len(choices)]:
33
+ return options.index(prediction)
34
+ else:
35
+ return -1
36
+ return random.choice(range(len(choices)))
37
+
38
+
39
+ if __name__ == "__main__":
40
+ args = get_args()
41
+
42
+ base_dir = args.base_dir
43
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
45
+ predictions = [json.loads(line) for line in open(args.result_file)]
46
+ predictions = {pred['question_id']: pred for pred in predictions}
47
+ split_problems = {idx: problems[idx] for idx in split_indices}
48
+
49
+ results = {'correct': [], 'incorrect': []}
50
+ sqa_results = {}
51
+ sqa_results['acc'] = None
52
+ sqa_results['correct'] = None
53
+ sqa_results['count'] = None
54
+ sqa_results['results'] = {}
55
+ sqa_results['outputs'] = {}
56
+
57
+ for prob_id, prob in split_problems.items():
58
+ if prob_id not in predictions:
59
+ pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60
+ pred_text = 'FAILED'
61
+ else:
62
+ pred = predictions[prob_id]
63
+ pred_text = pred['text']
64
+
65
+ if pred_text in args.options:
66
+ answer = pred_text
67
+ elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68
+ answer = pred_text[0]
69
+ else:
70
+ pattern = re.compile(r'The answer is ([A-Z]).')
71
+ res = pattern.findall(pred_text)
72
+ if len(res) == 1:
73
+ answer = res[0] # 'A', 'B', ...
74
+ else:
75
+ answer = "FAILED"
76
+
77
+ pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78
+
79
+ analysis = {
80
+ 'question_id': prob_id,
81
+ 'parsed_ans': answer,
82
+ 'ground_truth': args.options[prob['answer']],
83
+ 'question': pred['prompt'],
84
+ 'pred': pred_text,
85
+ 'is_multimodal': '<image>' in pred['prompt'],
86
+ }
87
+
88
+ sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89
+ sqa_results['outputs'][prob_id] = pred_text
90
+
91
+ if pred_idx == prob['answer']:
92
+ results['correct'].append(analysis)
93
+ else:
94
+ results['incorrect'].append(analysis)
95
+
96
+ correct = len(results['correct'])
97
+ total = len(results['correct']) + len(results['incorrect'])
98
+
99
+ ###### IMG ######
100
+ multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101
+ multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102
+ multimodal_total = multimodal_correct + multimodal_incorrect
103
+ ###### IMG ######
104
+
105
+ print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106
+
107
+ sqa_results['acc'] = correct / total * 100
108
+ sqa_results['correct'] = correct
109
+ sqa_results['count'] = total
110
+
111
+ with open(args.output_file, 'w') as f:
112
+ json.dump(results, f, indent=2)
113
+ with open(args.output_result, 'w') as f:
114
+ json.dump(sqa_results, f, indent=2)
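A brief sketch of the three answer-parsing paths above, using the default options A-E and a hypothetical four-choice question:

    options = ["A", "B", "C", "D", "E"]
    choices = ["ant", "bee", "cat", "dog"]        # hypothetical choices
    get_pred_idx("B", choices, options)           # bare letter                      -> 1
    # "C. Because ..."    matches the "X. " prefix branch  -> answer "C"             -> 2
    # "The answer is D."  matches the regex branch         -> answer "D"             -> 3
    # anything else       is recorded as "FAILED"          -> -1 (scored incorrect)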
eval_textvqa.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import argparse
3
+ import json
4
+ import re
5
+
6
+ from tinyllava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7
+
8
+
9
+ def get_args():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--annotation-file', type=str)
12
+ parser.add_argument('--result-file', type=str)
13
+ parser.add_argument('--result-dir', type=str)
14
+ return parser.parse_args()
15
+
16
+
17
+ def prompt_processor(prompt):
18
+ if prompt.startswith('OCR tokens: '):
19
+ pattern = r"Question: (.*?) Short answer:"
20
+ match = re.search(pattern, prompt, re.DOTALL)
21
+ question = match.group(1)
22
+ elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23
+ if prompt.startswith('Reference OCR token:'):
24
+ question = prompt.split('\n')[1]
25
+ else:
26
+ question = prompt.split('\n')[0]
27
+ elif len(prompt.split('\n')) == 2:
28
+ question = prompt.split('\n')[0]
29
+ else:
30
+ assert False
31
+
32
+ return question.lower()
33
+
34
+
35
+ def eval_single(annotation_file, result_file):
36
+ experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37
+ print(experiment_name)
38
+ annotations = json.load(open(annotation_file))['data']
39
+ annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40
+ results = [json.loads(line) for line in open(result_file)]
41
+
42
+ pred_list = []
43
+ for result in results:
44
+ annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45
+ pred_list.append({
46
+ "pred_answer": result['text'],
47
+ "gt_answers": annotation['answers'],
48
+ })
49
+
50
+ evaluator = TextVQAAccuracyEvaluator()
51
+ print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52
+
53
+
54
+ if __name__ == "__main__":
55
+ args = get_args()
56
+
57
+ if args.result_file is not None:
58
+ eval_single(args.annotation_file, args.result_file)
59
+
60
+ if args.result_dir is not None:
61
+ for result_file in sorted(os.listdir(args.result_dir)):
62
+ if not result_file.endswith('.jsonl'):
63
+ print(f'Skipping {result_file}')
64
+ continue
65
+ eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
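prompt_processor above recovers the lower-cased question from the handful of prompt layouts used at inference time; the (image_id, question) pair then links each prediction to its TextVQA annotation. A small sketch with made-up prompts:

    prompt_processor("OCR tokens: stop, 25 Question: What does the sign say? Short answer:")
    # -> "what does the sign say?"
    prompt_processor("What brand is shown?\nAnswer with a single word or phrase.")
    # -> "what brand is shown?"   (two-line prompt branch)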
extreme_ironing.jpg ADDED
generate_webpage_data_from_table.py ADDED
@@ -0,0 +1,111 @@
1
+ """Generate json file for webpage."""
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ # models = ['llama', 'alpaca', 'gpt35', 'bard']
7
+ models = ['vicuna']
8
+
9
+
10
+ def read_jsonl(path: str, key: str=None):
11
+ data = []
12
+ with open(os.path.expanduser(path)) as f:
13
+ for line in f:
14
+ if not line:
15
+ continue
16
+ data.append(json.loads(line))
17
+ if key is not None:
18
+ data.sort(key=lambda x: x[key])
19
+ data = {item[key]: item for item in data}
20
+ return data
21
+
22
+
23
+ def trim_hanging_lines(s: str, n: int) -> str:
24
+ s = s.strip()
25
+ for _ in range(n):
26
+ s = s.split('\n', 1)[1].strip()
27
+ return s
28
+
29
+
30
+ if __name__ == '__main__':
31
+ questions = read_jsonl('table/question.jsonl', key='question_id')
32
+
33
+ # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34
+ # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35
+ # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36
+ # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37
+ vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38
+ ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39
+
40
+ review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41
+ # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42
+ # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43
+ # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44
+ # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45
+
46
+ records = []
47
+ for qid in questions.keys():
48
+ r = {
49
+ 'id': qid,
50
+ 'category': questions[qid]['category'],
51
+ 'question': questions[qid]['text'],
52
+ 'answers': {
53
+ # 'alpaca': alpaca_answers[qid]['text'],
54
+ # 'llama': llama_answers[qid]['text'],
55
+ # 'bard': bard_answers[qid]['text'],
56
+ # 'gpt35': gpt35_answers[qid]['text'],
57
+ 'vicuna': vicuna_answers[qid]['text'],
58
+ 'ours': ours_answers[qid]['text'],
59
+ },
60
+ 'evaluations': {
61
+ # 'alpaca': review_alpaca[qid]['text'],
62
+ # 'llama': review_llama[qid]['text'],
63
+ # 'bard': review_bard[qid]['text'],
64
+ 'vicuna': review_vicuna[qid]['content'],
65
+ # 'gpt35': review_gpt35[qid]['text'],
66
+ },
67
+ 'scores': {
68
+ 'vicuna': review_vicuna[qid]['tuple'],
69
+ # 'alpaca': review_alpaca[qid]['score'],
70
+ # 'llama': review_llama[qid]['score'],
71
+ # 'bard': review_bard[qid]['score'],
72
+ # 'gpt35': review_gpt35[qid]['score'],
73
+ },
74
+ }
75
+
76
+ # cleanup data
77
+ cleaned_evals = {}
78
+ for k, v in r['evaluations'].items():
79
+ v = v.strip()
80
+ lines = v.split('\n')
81
+ # trim the first line if it's a pair of numbers
82
+ if re.match(r'\d+[, ]+\d+', lines[0]):
83
+ lines = lines[1:]
84
+ v = '\n'.join(lines)
85
+ cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86
+
87
+ r['evaluations'] = cleaned_evals
88
+ records.append(r)
89
+
90
+ # Reorder the records, this is optional
91
+ for r in records:
92
+ if r['id'] <= 20:
93
+ r['id'] += 60
94
+ else:
95
+ r['id'] -= 20
96
+ for r in records:
97
+ if r['id'] <= 50:
98
+ r['id'] += 10
99
+ elif 50 < r['id'] <= 60:
100
+ r['id'] -= 50
101
+ for r in records:
102
+ if r['id'] == 7:
103
+ r['id'] = 1
104
+ elif r['id'] < 7:
105
+ r['id'] += 1
106
+
107
+ records.sort(key=lambda x: x['id'])
108
+
109
+ # Write to file
110
+ with open('webpage/data.json', 'w') as f:
111
+ json.dump({'questions': records, 'models': models}, f, indent=2)
gradio_web_server.py ADDED
@@ -0,0 +1,472 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+
7
+ import gradio as gr
8
+ import requests
9
+
10
+ from tinyllava.conversation import (default_conversation, conv_templates,
11
+ SeparatorStyle)
12
+ from tinyllava.constants import LOGDIR
13
+ from tinyllava.utils import (build_logger, server_error_msg,
14
+ violates_moderation, moderation_msg)
15
+ import hashlib
16
+
17
+
18
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
19
+
20
+ headers = {"User-Agent": "LLaVA Client"}
21
+
22
+ no_change_btn = gr.Button.update()
23
+ enable_btn = gr.Button.update(interactive=True)
24
+ disable_btn = gr.Button.update(interactive=False)
25
+
26
+ priority = {
27
+ "vicuna-13b": "aaaaaaa",
28
+ "koala-13b": "aaaaaab",
29
+ }
30
+
31
+
32
+ def get_conv_log_filename():
33
+ t = datetime.datetime.now()
34
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
35
+ return name
36
+
37
+
38
+ def get_model_list():
39
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
40
+ assert ret.status_code == 200
41
+ ret = requests.post(args.controller_url + "/list_models")
42
+ models = ret.json()["models"]
43
+ models.sort(key=lambda x: priority.get(x, x))
44
+ logger.info(f"Models: {models}")
45
+ return models
46
+
47
+
48
+ get_window_url_params = """
49
+ function() {
50
+ const params = new URLSearchParams(window.location.search);
51
+ url_params = Object.fromEntries(params);
52
+ console.log(url_params);
53
+ return url_params;
54
+ }
55
+ """
56
+
57
+
58
+ def load_demo(url_params, request: gr.Request):
59
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
60
+
61
+ dropdown_update = gr.Dropdown.update(visible=True)
62
+ if "model" in url_params:
63
+ model = url_params["model"]
64
+ if model in models:
65
+ dropdown_update = gr.Dropdown.update(
66
+ value=model, visible=True)
67
+
68
+ state = default_conversation.copy()
69
+ return state, dropdown_update
70
+
71
+
72
+ def load_demo_refresh_model_list(request: gr.Request):
73
+ logger.info(f"load_demo. ip: {request.client.host}")
74
+ models = get_model_list()
75
+ state = default_conversation.copy()
76
+ dropdown_update = gr.Dropdown.update(
77
+ choices=models,
78
+ value=models[0] if len(models) > 0 else ""
79
+ )
80
+ return state, dropdown_update
81
+
82
+
83
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
84
+ with open(get_conv_log_filename(), "a") as fout:
85
+ data = {
86
+ "tstamp": round(time.time(), 4),
87
+ "type": vote_type,
88
+ "model": model_selector,
89
+ "state": state.dict(),
90
+ "ip": request.client.host,
91
+ }
92
+ fout.write(json.dumps(data) + "\n")
93
+
94
+
95
+ def upvote_last_response(state, model_selector, request: gr.Request):
96
+ logger.info(f"upvote. ip: {request.client.host}")
97
+ vote_last_response(state, "upvote", model_selector, request)
98
+ return ("",) + (disable_btn,) * 3
99
+
100
+
101
+ def downvote_last_response(state, model_selector, request: gr.Request):
102
+ logger.info(f"downvote. ip: {request.client.host}")
103
+ vote_last_response(state, "downvote", model_selector, request)
104
+ return ("",) + (disable_btn,) * 3
105
+
106
+
107
+ def flag_last_response(state, model_selector, request: gr.Request):
108
+ logger.info(f"flag. ip: {request.client.host}")
109
+ vote_last_response(state, "flag", model_selector, request)
110
+ return ("",) + (disable_btn,) * 3
111
+
112
+
113
+ def regenerate(state, image_process_mode, request: gr.Request):
114
+ logger.info(f"regenerate. ip: {request.client.host}")
115
+ state.messages[-1][-1] = None
116
+ prev_human_msg = state.messages[-2]
117
+ if type(prev_human_msg[1]) in (tuple, list):
118
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
119
+ state.skip_next = False
120
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
121
+
122
+
123
+ def clear_history(request: gr.Request):
124
+ logger.info(f"clear_history. ip: {request.client.host}")
125
+ state = default_conversation.copy()
126
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
127
+
128
+
129
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
130
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
131
+ if len(text) <= 0 and image is None:
132
+ state.skip_next = True
133
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
134
+ if args.moderate:
135
+ flagged = violates_moderation(text)
136
+ if flagged:
137
+ state.skip_next = True
138
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
139
+ no_change_btn,) * 5
140
+
141
+ text = text[:1536] # Hard cut-off
142
+ if image is not None:
143
+ text = text[:1200] # Hard cut-off for images
144
+ if '<image>' not in text:
145
+ # text = '<Image><image></Image>' + text
146
+ text = text + '\n<image>'
147
+ text = (text, image, image_process_mode)
148
+ if len(state.get_images(return_pil=True)) > 0:
149
+ state = default_conversation.copy()
150
+ state.append_message(state.roles[0], text)
151
+ state.append_message(state.roles[1], None)
152
+ state.skip_next = False
153
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
154
+
155
+
156
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
157
+ logger.info(f"http_bot. ip: {request.client.host}")
158
+ start_tstamp = time.time()
159
+ model_name = model_selector
160
+
161
+ if state.skip_next:
162
+ # This generate call is skipped due to invalid inputs
163
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
164
+ return
165
+
166
+ if len(state.messages) == state.offset + 2:
167
+ # First round of conversation
168
+ if "tinyllava" in model_name.lower():
169
+ if 'llama-2' in model_name.lower():
170
+ template_name = "llava_llama_2"
171
+ elif "v1" in model_name.lower():
172
+ if 'mmtag' in model_name.lower():
173
+ template_name = "v1_mmtag"
174
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
175
+ template_name = "v1_mmtag"
176
+ else:
177
+ template_name = "llava_v1"
178
+ elif 'phi' in model_name.lower():
179
+ template_name = "phi"
180
+ elif "mpt" in model_name.lower():
181
+ template_name = "mpt"
182
+ else:
183
+ if 'mmtag' in model_name.lower():
184
+ template_name = "v0_mmtag"
185
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
186
+ template_name = "v0_mmtag"
187
+ else:
188
+ template_name = "llava_v0"
189
+ elif "mpt" in model_name:
190
+ template_name = "mpt_text"
191
+ elif "llama-2" in model_name:
192
+ template_name = "llama_2"
193
+ else:
194
+ template_name = "vicuna_v1"
195
+ new_state = conv_templates[template_name].copy()
196
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
197
+ new_state.append_message(new_state.roles[1], None)
198
+ state = new_state
199
+
200
+ # Query worker address
201
+ controller_url = args.controller_url
202
+ ret = requests.post(controller_url + "/get_worker_address",
203
+ json={"model": model_name})
204
+ worker_addr = ret.json()["address"]
205
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
206
+
207
+ # No available worker
208
+ if worker_addr == "":
209
+ state.messages[-1][-1] = server_error_msg
210
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
211
+ return
212
+
213
+ # Construct prompt
214
+ prompt = state.get_prompt()
215
+
216
+ all_images = state.get_images(return_pil=True)
217
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
218
+ for image, hash in zip(all_images, all_image_hash):
219
+ t = datetime.datetime.now()
220
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
221
+ if not os.path.isfile(filename):
222
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
223
+ image.save(filename)
224
+
225
+ # Make requests
226
+ pload = {
227
+ "model": model_name,
228
+ "prompt": prompt,
229
+ "temperature": float(temperature),
230
+ "top_p": float(top_p),
231
+ "max_new_tokens": min(int(max_new_tokens), 1536),
232
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
233
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
234
+ }
235
+ logger.info(f"==== request ====\n{pload}")
236
+
237
+ pload['images'] = state.get_images()
238
+
239
+ state.messages[-1][-1] = "▌"
240
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
241
+
242
+ try:
243
+ # Stream output
244
+ response = requests.post(worker_addr + "/worker_generate_stream",
245
+ headers=headers, json=pload, stream=True, timeout=10)
246
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
247
+ if chunk:
248
+ data = json.loads(chunk.decode())
249
+ if data["error_code"] == 0:
250
+ output = data["text"][len(prompt):].strip()
251
+ state.messages[-1][-1] = output + "▌"
252
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
253
+ else:
254
+ output = data["text"] + f" (error_code: {data['error_code']})"
255
+ state.messages[-1][-1] = output
256
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
257
+ return
258
+ time.sleep(0.03)
259
+ except requests.exceptions.RequestException as e:
260
+ state.messages[-1][-1] = server_error_msg
261
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
262
+ return
263
+
264
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
265
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
266
+
267
+ finish_tstamp = time.time()
268
+ logger.info(f"{output}")
269
+
270
+ with open(get_conv_log_filename(), "a") as fout:
271
+ data = {
272
+ "tstamp": round(finish_tstamp, 4),
273
+ "type": "chat",
274
+ "model": model_name,
275
+ "start": round(start_tstamp, 4),
276
+ "finish": round(finish_tstamp, 4),
277
+ "state": state.dict(),
278
+ "images": all_image_hash,
279
+ "ip": request.client.host,
280
+ }
281
+ fout.write(json.dumps(data) + "\n")
282
+
283
+ title_markdown = ("""
284
+ # 🌋 LLaVA: Large Language and Vision Assistant
285
+ [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
286
+ """)
287
+
288
+ tos_markdown = ("""
289
+ ### Terms of use
290
+ By using this service, users are required to agree to the following terms:
291
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
292
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
293
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
294
+ """)
295
+
296
+
297
+ learn_more_markdown = ("""
298
+ ### License
299
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
300
+ """)
301
+
302
+ block_css = """
303
+
304
+ #buttons button {
305
+ min-width: min(120px,100%);
306
+ }
307
+
308
+ """
309
+
310
+ def build_demo(embed_mode):
311
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
312
+ with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
313
+ state = gr.State()
314
+
315
+ if not embed_mode:
316
+ gr.Markdown(title_markdown)
317
+
318
+ with gr.Row():
319
+ with gr.Column(scale=3):
320
+ with gr.Row(elem_id="model_selector_row"):
321
+ model_selector = gr.Dropdown(
322
+ choices=models,
323
+ value=models[0] if len(models) > 0 else "",
324
+ interactive=True,
325
+ show_label=False,
326
+ container=False)
327
+
328
+ imagebox = gr.Image(type="pil")
329
+ image_process_mode = gr.Radio(
330
+ ["Crop", "Resize", "Pad", "Default"],
331
+ value="Default",
332
+ label="Preprocess for non-square image", visible=False)
333
+
334
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
335
+ gr.Examples(examples=[
336
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
337
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
338
+ ], inputs=[imagebox, textbox])
339
+
340
+ with gr.Accordion("Parameters", open=False) as parameter_row:
341
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
342
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
343
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
344
+
345
+ with gr.Column(scale=8):
346
+ chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
347
+ with gr.Row():
348
+ with gr.Column(scale=8):
349
+ textbox.render()
350
+ with gr.Column(scale=1, min_width=50):
351
+ submit_btn = gr.Button(value="Send", variant="primary")
352
+ with gr.Row(elem_id="buttons") as button_row:
353
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
354
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
355
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
356
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
357
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
358
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
359
+
360
+ if not embed_mode:
361
+ gr.Markdown(tos_markdown)
362
+ gr.Markdown(learn_more_markdown)
363
+ url_params = gr.JSON(visible=False)
364
+
365
+ # Register listeners
366
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
367
+ upvote_btn.click(
368
+ upvote_last_response,
369
+ [state, model_selector],
370
+ [textbox, upvote_btn, downvote_btn, flag_btn],
371
+ queue=False
372
+ )
373
+ downvote_btn.click(
374
+ downvote_last_response,
375
+ [state, model_selector],
376
+ [textbox, upvote_btn, downvote_btn, flag_btn],
377
+ queue=False
378
+ )
379
+ flag_btn.click(
380
+ flag_last_response,
381
+ [state, model_selector],
382
+ [textbox, upvote_btn, downvote_btn, flag_btn],
383
+ queue=False
384
+ )
385
+
386
+ regenerate_btn.click(
387
+ regenerate,
388
+ [state, image_process_mode],
389
+ [state, chatbot, textbox, imagebox] + btn_list,
390
+ queue=False
391
+ ).then(
392
+ http_bot,
393
+ [state, model_selector, temperature, top_p, max_output_tokens],
394
+ [state, chatbot] + btn_list
395
+ )
396
+
397
+ clear_btn.click(
398
+ clear_history,
399
+ None,
400
+ [state, chatbot, textbox, imagebox] + btn_list,
401
+ queue=False
402
+ )
403
+
404
+ textbox.submit(
405
+ add_text,
406
+ [state, textbox, imagebox, image_process_mode],
407
+ [state, chatbot, textbox, imagebox] + btn_list,
408
+ queue=False
409
+ ).then(
410
+ http_bot,
411
+ [state, model_selector, temperature, top_p, max_output_tokens],
412
+ [state, chatbot] + btn_list
413
+ )
414
+
415
+ submit_btn.click(
416
+ add_text,
417
+ [state, textbox, imagebox, image_process_mode],
418
+ [state, chatbot, textbox, imagebox] + btn_list,
419
+ queue=False
420
+ ).then(
421
+ http_bot,
422
+ [state, model_selector, temperature, top_p, max_output_tokens],
423
+ [state, chatbot] + btn_list
424
+ )
425
+
426
+ if args.model_list_mode == "once":
427
+ demo.load(
428
+ load_demo,
429
+ [url_params],
430
+ [state, model_selector],
431
+ _js=get_window_url_params,
432
+ queue=False
433
+ )
434
+ elif args.model_list_mode == "reload":
435
+ demo.load(
436
+ load_demo_refresh_model_list,
437
+ None,
438
+ [state, model_selector],
439
+ queue=False
440
+ )
441
+ else:
442
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
443
+
444
+ return demo
445
+
446
+
447
+ if __name__ == "__main__":
448
+ parser = argparse.ArgumentParser()
449
+ parser.add_argument("--host", type=str, default="127.0.0.1")
450
+ parser.add_argument("--port", type=int, default=6006)
451
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
452
+ parser.add_argument("--concurrency-count", type=int, default=10)
453
+ parser.add_argument("--model-list-mode", type=str, default="once",
454
+ choices=["once", "reload"])
455
+ parser.add_argument("--share", action="store_true")
456
+ parser.add_argument("--moderate", action="store_true")
457
+ parser.add_argument("--embed", action="store_true")
458
+ args = parser.parse_args()
459
+ logger.info(f"args: {args}")
460
+
461
+ models = get_model_list()
462
+
463
+ logger.info(args)
464
+ demo = build_demo(args.embed)
465
+ demo.queue(
466
+ concurrency_count=args.concurrency_count,
467
+ api_open=False
468
+ ).launch(
469
+ server_name=args.host,
470
+ server_port=args.port,
471
+ share=args.share
472
+ )
index.html ADDED
@@ -0,0 +1,162 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</title>
7
+ <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css">
8
+ <link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
9
+ <link rel="stylesheet" href="styles.css">
10
+ </head>
11
+
12
+ <body>
13
+ <nav class="navbar navbar-expand-lg navbar-dark bg-dark">
14
+ <a class="navbar-brand" href="#">🏔️ Vicuna Evaluation Examples</a>
15
+ <button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarNav" aria-controls="navbarNav" aria-expanded="false" aria-label="Toggle navigation">
16
+ <span class="navbar-toggler-icon"></span>
17
+ </button>
18
+ <div class="collapse navbar-collapse" id="navbarNav">
19
+ <ul class="navbar-nav mr-auto">
20
+ <li class="nav-item">
21
+ <a class="nav-link" href="https://chat.lmsys.org/">Demo</a>
22
+ </li>
23
+ <li class="nav-item">
24
+ <a class="nav-link" href="https://vicuna.lmsys.org">Blog</a>
25
+ </li>
26
+ <li class="nav-item">
27
+ <a class="nav-link" href="https://github.com/lm-sys/FastChat">Github</a>
28
+ </li>
29
+ </ul>
30
+ </div>
31
+ </nav>
32
+
33
+ <div class="container mt-5">
34
+ <h2 class="text-center mb-5">Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</h2>
35
+
36
+ <!-- Selection -->
37
+ <div class="form-row">
38
+ <div class="form-group col-md-2">
39
+ <label for="category-select">Category</label>
40
+ <select class="form-control" id="category-select"></select>
41
+ </div>
42
+ <div class="form-group col-md-8">
43
+ <label for="question-select">Question</label>
44
+ <select class="form-control" id="question-select"></select>
45
+ </div>
46
+ <div class="form-group col-md-2">
47
+ <div class="col-md-2"><label>&nbsp;</label></div>
48
+ <div class="btn-group" role="group" aria-label="Left and Right Controller">
49
+ <button type="button" class="form-control btn btn-primary" id="prev-question"><i class="material-icons">keyboard_arrow_left</i></button>
50
+ <button type="button" class="form-control btn btn-primary" id="next-question"><i class="material-icons">keyboard_arrow_right</i></button>
51
+ </div>
52
+ </div>
53
+ </div>
54
+
55
+ <!-- "Battle" -->
56
+ <div class="row mb-4" style="justify-content: center;">
57
+ <div class="col" style="display: flex; justify-content: center; align-items: center;">
58
+ <label class="adjustable-font-size" id="other-score-label">*/10</label>
59
+ </div>
60
+ <div class="col">
61
+ <div class="vertical-flex-layout">
62
+ <img class="shadow figure-img img-fluid" src="" alt="other logo" width="150" id="other-model-figure">
63
+ </div>
64
+ </div>
65
+ <div class="col">
66
+ <div class="vertical-flex-layout">
67
+ <!-- from: https://fonts.google.com/icons?icon.query=battle&selected=Material+Symbols+Outlined:swords:FILL@0;wght@300;GRAD@0;opsz@48&icon.style=Outlined -->
68
+ <img class="figure-img img-fluid" src="figures/swords_FILL0_wght300_GRAD0_opsz48.svg" width="60" height="60">
69
+ </div>
70
+ </div>
71
+ <div class="col">
72
+ <div class="vertical-flex-layout">
73
+ <img class="shadow figure-img img-fluid" src="figures/vicuna.jpeg" alt="vicuna logo" width="150" id="our-model-figure">
74
+ </div>
75
+ </div>
76
+ <div class="col" style="display: flex; justify-content: center; align-items: center;">
77
+ <label class="adjustable-font-size" id="our-score-label">*/10</label>
78
+ </div>
79
+ </div>
80
+
81
+ <!-- Question Card -->
82
+ <div class="card mb-4">
83
+ <div class="card-body" id="selected-question"></div>
84
+ </div>
85
+
86
+ <!-- Answer Cards -->
87
+ <div class="row">
88
+ <div class="col-md-6">
89
+ <div class="card mb-4 expandable-card">
90
+ <div class="card-header" style="padding-bottom: 0.2rem" id="other-model-header-bg">
91
+ <div class="row">
92
+ <div class="col-md-5" style="align-items: center; display: flex;">
93
+ <label id="other-model-header">Assistant #1</label>
94
+ </div>
95
+ <div class="col-md-7">
96
+ <select class="form-control" id="model-select" style="height: fit-content; margin-top: -0.3rem;"></select>
97
+ </div>
98
+ </div>
99
+ </div>
100
+ <div class="card-body">
101
+ <div class="card-text-container">
102
+ <div class="card-text" id="other-model-answer"></div>
103
+ </div>
104
+ <div class="btn btn-primary expand-btn" style="display:flex;"></div>
105
+ </div>
106
+ </div>
107
+ </div>
108
+ <div class="col-md-6">
109
+ <div class="card mb-4 expandable-card">
110
+ <div class="card-header" id="our-model-header">
111
+ Assistant #2 (Vicuna, our model)
112
+ </div>
113
+ <div class="card-body">
114
+ <div class="card-text-container">
115
+ <div class="card-text" id="our-model-answer"></div>
116
+ </div>
117
+ <div class="btn btn-primary expand-btn" style="display:flex;"></div>
118
+ </div>
119
+ </div>
120
+ </div>
121
+ </div>
122
+
123
+ <!-- Evaluation -->
124
+ <div class="card expandable-card">
125
+ <div class="card-header" style="background-color: #c9c9f2;" id="evaluation-header">GPT-4 Evaluation</div>
126
+ <div class="card-body">
127
+ <div class="card-text-container">
128
+ <div class="card-text" id="evaluation-result"></div>
129
+ </div>
130
+ <div class="btn btn-primary expand-btn" style="display:flex;"></div>
131
+ </div>
132
+ </div>
133
+ </div>
134
+
135
+ <div class="container-fluid bg-light py-2">
136
+ <div class="text-center">
137
+ <small class="text-muted">This website is co-authored with <a href="https://openai.com" target="_blank">GPT-4</a>.</small>
138
+ </div>
139
+ </div>
140
+
141
+ <!-- Marked.js -->
142
+ <script src="https://cdn.jsdelivr.net/npm/marked@4.3.0/lib/marked.umd.min.js"></script>
143
+ <!-- Bootstrap and Popper.js JavaScript dependencies -->
144
+ <script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
145
+ <script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.11.6/dist/umd/popper.min.js"></script>
146
+ <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"></script>
147
+
148
+ <script src="script.js"></script>
149
+ <script>
150
+ // Fetch the JSON file
151
+ fetch('data.json')
152
+ .then(response => response.json())
153
+ .then(json_data => {
154
+ // Populate the models and questions.
155
+ populateModels(json_data.models);
156
+ populateQuestions(json_data.questions);
157
+ displayQuestion(currentQuestionIndex);
158
+ }).catch(error => console.error(error));
159
+ </script>
160
+ </body>
161
+
162
+ </html>
llama.jpg ADDED
llava_arch.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright 2024 Baichuan Zhou, Junlong Jia, Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from tinyllava.model.multimodal_encoder.builder import build_vision_tower
22
+ from tinyllava.model.multimodal_projector.builder import build_vision_projector
23
+
24
+ from tinyllava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+ from tinyllava.mm_utils import get_anyres_image_grid_shape
27
+
28
+
29
+ class LlavaMetaModel:
30
+
31
+ def __init__(self, config):
32
+ super(LlavaMetaModel, self).__init__(config)
33
+
34
+ if hasattr(config, "mm_vision_tower"):
35
+ self.vision_tower = build_vision_tower(config, delay_load=True)
36
+ self.mm_projector = build_vision_projector(config)
37
+
38
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
39
+ self.image_newline = nn.Parameter(
40
+ torch.empty(config.hidden_size, dtype=self.dtype)
41
+ )
42
+
43
+ def get_vision_tower(self):
44
+ vision_tower = getattr(self, 'vision_tower', None)
45
+ if type(vision_tower) is list:
46
+ vision_tower = vision_tower[0]
47
+ return vision_tower
48
+
49
+ def initialize_vision_modules(self, model_args, fsdp=None):
50
+ vision_tower = model_args.vision_tower
51
+ mm_vision_select_layer = model_args.mm_vision_select_layer
52
+ mm_vision_select_feature = model_args.mm_vision_select_feature
53
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
54
+ mm_patch_merge_type = model_args.mm_patch_merge_type
55
+
56
+ self.config.mm_vision_tower = vision_tower
57
+
58
+ if self.get_vision_tower() is None:
59
+ vision_tower = build_vision_tower(model_args)
60
+
61
+ if fsdp is not None and len(fsdp) > 0:
62
+ self.vision_tower = [vision_tower]
63
+ else:
64
+ self.vision_tower = vision_tower
65
+
66
+ elif self.get_vision_tower().vision_tower_name != vision_tower:
67
+ print(f"rebuilding vision tower! vision tower initialized from: {vision_tower}")
68
+ vision_tower = build_vision_tower(model_args)
69
+ if fsdp is not None and len(fsdp) > 0:
70
+ self.vision_tower = [vision_tower]
71
+ else:
72
+ self.vision_tower = vision_tower
73
+
74
+ else:
75
+ if fsdp is not None and len(fsdp) > 0:
76
+ vision_tower = self.vision_tower[0]
77
+ else:
78
+ vision_tower = self.vision_tower
79
+ vision_tower.load_model()
80
+
81
+ self.config.use_mm_proj = True
82
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
83
+ self.config.mm_hidden_size = vision_tower.hidden_size
84
+ self.config.mm_vision_select_layer = mm_vision_select_layer
85
+ self.config.mm_vision_select_feature = mm_vision_select_feature
86
+ self.config.mm_patch_merge_type = mm_patch_merge_type
87
+
88
+ if getattr(self, 'mm_projector', None) is None:
89
+ self.mm_projector = build_vision_projector(self.config)
90
+
91
+ if 'unpad' in mm_patch_merge_type:
92
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
93
+ self.image_newline = nn.Parameter(
94
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
95
+ )
96
+ else:
97
+ # In case it is frozen by LoRA
98
+ for p in self.mm_projector.parameters():
99
+ p.requires_grad = True
100
+
101
+ if pretrain_mm_mlp_adapter is not None:
102
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
103
+ def get_w(weights, keyword):
104
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
105
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
106
+
107
+
108
+ def unpad_image(tensor, original_size):
109
+ """
110
+ Unpads a PyTorch tensor of a padded and resized image.
111
+
112
+ Args:
113
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
114
+ original_size (tuple): The original size of the image (width, height).
115
+
116
+ Returns:
117
+ torch.Tensor: The unpadded image tensor.
118
+ """
119
+ original_width, original_height = original_size
120
+ current_height, current_width = tensor.shape[1:]
121
+
122
+ original_aspect_ratio = original_width / original_height
123
+ current_aspect_ratio = current_width / current_height
124
+
125
+ if original_aspect_ratio > current_aspect_ratio:
126
+ scale_factor = current_width / original_width
127
+ new_height = int(original_height * scale_factor)
128
+ padding = (current_height - new_height) // 2
129
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
130
+ else:
131
+ scale_factor = current_height / original_height
132
+ new_width = int(original_width * scale_factor)
133
+ padding = (current_width - new_width) // 2
134
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
135
+
136
+ return unpadded_tensor
137
+
138
+
139
+ class LlavaMetaForCausalLM(ABC):
140
+
141
+ @abstractmethod
142
+ def get_model(self):
143
+ pass
144
+
145
+ def get_vision_tower(self):
146
+ return self.get_model().get_vision_tower()
147
+
148
+ def encode_images(self, images):
149
+ image_features = self.get_model().get_vision_tower()(images)
150
+ image_features = self.get_model().mm_projector(image_features)
151
+ return image_features
152
+
153
+ def prepare_inputs_labels_for_multimodal(
154
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
155
+ images, image_sizes=None
156
+ ):
157
+ vision_tower = self.get_vision_tower()
158
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
159
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
160
+
161
+ if type(images) is list or images.ndim == 5:
162
+ if type(images) is list:
163
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
164
+ concat_images = torch.cat([image for image in images], dim=0)
165
+ image_features = self.encode_images(concat_images)
166
+ split_sizes = [image.shape[0] for image in images]
167
+ image_features = torch.split(image_features, split_sizes, dim=0)
168
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
169
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
170
+ if mm_patch_merge_type == 'flat':
171
+ image_features = [x.flatten(0, 1) for x in image_features]
172
+ elif mm_patch_merge_type.startswith('spatial'):
173
+ new_image_features = []
174
+ for image_idx, image_feature in enumerate(image_features):
175
+ if image_feature.shape[0] > 1:
176
+ base_image_feature = image_feature[0]
177
+ image_feature = image_feature[1:]
178
+ height = width = self.get_vision_tower().num_patches_per_side
179
+ assert height * width == base_image_feature.shape[0]
180
+ if image_aspect_ratio == 'anyres':
181
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
182
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
183
+ else:
184
+ raise NotImplementedError
185
+ if 'unpad' in mm_patch_merge_type:
186
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
187
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
188
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
189
+ image_feature = torch.cat((
190
+ image_feature,
191
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
192
+ ), dim=-1)
193
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
194
+ else:
195
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
196
+ image_feature = image_feature.flatten(0, 3)
197
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
198
+ else:
199
+ image_feature = image_feature[0]
200
+ if 'unpad' in mm_patch_merge_type:
201
+ image_feature = torch.cat((
202
+ image_feature,
203
+ self.model.image_newline[None].to(image_feature.device)
204
+ ), dim=0)
205
+ new_image_features.append(image_feature)
206
+ image_features = new_image_features
207
+ else:
208
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
209
+ else:
210
+ image_features = self.encode_images(images)
211
+
212
+ # TODO: image start / end is not implemented here to support pretraining.
213
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
214
+ raise NotImplementedError
215
+
216
+ # Let's just add dummy tensors if they do not exist,
217
+ # it is a headache to deal with None all the time.
218
+ # But it is not ideal, and if you have a better idea,
219
+ # please open an issue / submit a PR, thanks.
220
+ _labels = labels
221
+ _position_ids = position_ids
222
+ _attention_mask = attention_mask
223
+ if attention_mask is None:
224
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
225
+ else:
226
+ attention_mask = attention_mask.bool()
227
+ if position_ids is None:
228
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
229
+ if labels is None:
230
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
231
+
232
+ # remove the padding using attention_mask -- FIXME
233
+ _input_ids = input_ids
234
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
235
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
236
+
237
+ new_input_embeds = []
238
+ new_labels = []
239
+ cur_image_idx = 0
240
+ for batch_idx, cur_input_ids in enumerate(input_ids):
241
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
242
+ if num_images == 0:
243
+ cur_image_features = image_features[cur_image_idx]
244
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
245
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
246
+ new_input_embeds.append(cur_input_embeds)
247
+ new_labels.append(labels[batch_idx])
248
+ cur_image_idx += 1
249
+ continue
250
+
251
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
252
+ cur_input_ids_noim = []
253
+ cur_labels = labels[batch_idx]
254
+ cur_labels_noim = []
255
+ for i in range(len(image_token_indices) - 1):
256
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
257
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
258
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
259
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
260
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
261
+ cur_new_input_embeds = []
262
+ cur_new_labels = []
263
+
264
+ for i in range(num_images + 1):
265
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
266
+ cur_new_labels.append(cur_labels_noim[i])
267
+ if i < num_images:
268
+ cur_image_features = image_features[cur_image_idx]
269
+ cur_image_idx += 1
270
+ cur_new_input_embeds.append(cur_image_features)
271
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
272
+
273
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
274
+
275
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
276
+ cur_new_labels = torch.cat(cur_new_labels)
277
+
278
+ new_input_embeds.append(cur_new_input_embeds)
279
+ new_labels.append(cur_new_labels)
280
+
281
+ # Truncate sequences to max length as image embeddings can make the sequence longer
282
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
283
+ if tokenizer_model_max_length is not None:
284
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
285
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
286
+
287
+ # Combine them
288
+ max_len = max(x.shape[0] for x in new_input_embeds)
289
+ batch_size = len(new_input_embeds)
290
+
291
+ new_input_embeds_padded = []
292
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
293
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
294
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
295
+
296
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
297
+ cur_len = cur_new_embed.shape[0]
298
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
299
+ new_input_embeds_padded.append(torch.cat((
300
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
301
+ cur_new_embed
302
+ ), dim=0))
303
+ if cur_len > 0:
304
+ new_labels_padded[i, -cur_len:] = cur_new_labels
305
+ attention_mask[i, -cur_len:] = True
306
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
307
+ else:
308
+ new_input_embeds_padded.append(torch.cat((
309
+ cur_new_embed,
310
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
311
+ ), dim=0))
312
+ if cur_len > 0:
313
+ new_labels_padded[i, :cur_len] = cur_new_labels
314
+ attention_mask[i, :cur_len] = True
315
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
316
+
317
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
318
+
319
+ if _labels is None:
320
+ new_labels = None
321
+ else:
322
+ new_labels = new_labels_padded
323
+
324
+ if _attention_mask is None:
325
+ attention_mask = None
326
+ else:
327
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
328
+
329
+ if _position_ids is None:
330
+ position_ids = None
331
+
332
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
333
+
334
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
335
+ if model_args.mm_use_im_patch_token:
336
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
337
+ self.resize_token_embeddings(len(tokenizer))
338
+
339
+ if model_args.mm_use_im_start_end:
340
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
341
+ self.resize_token_embeddings(len(tokenizer))
342
+
343
+ if num_new_tokens > 0:
344
+ input_embeddings = self.get_input_embeddings().weight.data
345
+ output_embeddings = self.get_output_embeddings().weight.data
346
+
347
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
348
+ dim=0, keepdim=True)
349
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
350
+ dim=0, keepdim=True)
351
+
352
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
353
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
354
+
355
+ if model_args.tune_mm_mlp_adapter:
356
+ for p in self.get_input_embeddings().parameters():
357
+ p.requires_grad = True
358
+ for p in self.get_output_embeddings().parameters():
359
+ p.requires_grad = False
360
+
361
+ if model_args.pretrain_mm_mlp_adapter:
362
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
363
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
364
+ assert num_new_tokens == 2
365
+ if input_embeddings.shape == embed_tokens_weight.shape:
366
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
367
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
368
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
369
+ else:
370
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
371
+ elif model_args.tune_embed_tokens:
372
+ for p in self.get_input_embeddings().parameters():
373
+ p.requires_grad = True
374
+ for p in self.get_output_embeddings().parameters():
375
+ p.requires_grad = False
376
+ print("Set input embeddings to trainable")
377
+
378
+ elif model_args.mm_use_im_patch_token:
379
+ if model_args.tune_mm_mlp_adapter:
380
+ for p in self.get_input_embeddings().parameters():
381
+ p.requires_grad = False
382
+ for p in self.get_output_embeddings().parameters():
383
+ p.requires_grad = False
384
+
385
+ if model_args.pretrain_mm_mlp_adapter:
386
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
387
+ if 'model.embed_tokens.weight' in mm_projector_weights.keys():
388
+ def get_w(weights, keyword):
389
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
390
+ self.get_model().embed_tokens.load_state_dict(get_w(mm_projector_weights, 'model.embed_tokens'))
llava_llama.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright 2024 Baichuan Zhou, Junlong Jia, Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, \
22
+ LlamaConfig, LlamaModel, LlamaForCausalLM
23
+ from transformers.generation.utils import GenerateOutput
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+ from tinyllava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+ from tinyllava.model.model_factory import *
29
+
30
+ class TinyLlavaConfig(LlamaConfig):
31
+ model_type = "tiny_llava"
32
+
33
+
34
+ class TinyLlavaLlamaModel(LlavaMetaModel, LlamaModel):
35
+ config_class = TinyLlavaConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(TinyLlavaLlamaModel, self).__init__(config)
39
+
40
+ @register_model('llama')
41
+ class TinyLlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = TinyLlavaConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = TinyLlavaLlamaModel(config)
47
+ self.pretraining_tp = config.pretraining_tp
48
+ self.vocab_size = config.vocab_size
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ # image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ inputs_embeds,
80
+ labels
81
+ ) = self.prepare_inputs_labels_for_multimodal(
82
+ input_ids,
83
+ position_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ labels,
87
+ images,
88
+ # image_sizes
89
+ )
90
+
91
+ return super().forward(
92
+ input_ids=input_ids,
93
+ attention_mask=attention_mask,
94
+ position_ids=position_ids,
95
+ past_key_values=past_key_values,
96
+ inputs_embeds=inputs_embeds,
97
+ labels=labels,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict
102
+ )
103
+
104
+ @torch.no_grad()
105
+ def generate(
106
+ self,
107
+ inputs: Optional[torch.Tensor] = None,
108
+ images: Optional[torch.Tensor] = None,
109
+ # image_sizes: Optional[torch.Tensor] = None,
110
+ **kwargs,
111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
112
+ position_ids = kwargs.pop("position_ids", None)
113
+ attention_mask = kwargs.pop("attention_mask", None)
114
+ if "inputs_embeds" in kwargs:
115
+ raise NotImplementedError("`inputs_embeds` is not supported")
116
+
117
+ if images is not None:
118
+ (
119
+ inputs,
120
+ position_ids,
121
+ attention_mask,
122
+ _,
123
+ inputs_embeds,
124
+ _
125
+ ) = self.prepare_inputs_labels_for_multimodal(
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ None,
130
+ None,
131
+ images,
132
+ # image_sizes=image_sizes
133
+ )
134
+ else:
135
+ inputs_embeds = self.get_model().embed_tokens(inputs)
136
+
137
+ return super().generate(
138
+ position_ids=position_ids,
139
+ attention_mask=attention_mask,
140
+ inputs_embeds=inputs_embeds,
141
+ **kwargs
142
+ )
143
+
144
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
+ inputs_embeds=None, **kwargs):
146
+ images = kwargs.pop("images", None)
147
+ image_sizes = kwargs.pop("image_sizes", None)
148
+ inputs = super().prepare_inputs_for_generation(
149
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
+ )
151
+ if images is not None:
152
+ inputs['images'] = images
153
+ if image_sizes is not None:
154
+ inputs['image_sizes'] = image_sizes
155
+ return inputs
156
+
157
+ @register_tokenizer('llama')
158
+ def get_tokenizer():
159
+ from transformers import AutoTokenizer
160
+ def post_init(tokenizer):
161
+ return tokenizer
162
+ return AutoTokenizer, post_init
163
+
164
+ AutoConfig.register("tiny_llava", TinyLlavaConfig)
165
+ AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaLlamaForCausalLM)
llava_phi.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright 2024 Baichuan Zhou, Junlong Jia, Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM
23
+
24
+ from transformers import PhiConfig, PhiModel, PhiForCausalLM
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from transformers.modeling_outputs import CausalLMOutputWithPast
28
+
29
+ from tinyllava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
30
+ from tinyllava.model.model_factory import *
31
+
32
+ class TinyLlavaPhiConfig(PhiConfig):
33
+ model_type = "tiny_llava_phi"
34
+
35
+
36
+ class TinyLlavaPhiModel(LlavaMetaModel, PhiModel):
37
+ config_class = TinyLlavaPhiConfig
38
+
39
+ def __init__(self, config: PhiConfig):
40
+ super(TinyLlavaPhiModel, self).__init__(config)
41
+ self.gradient_checkpointing = False
42
+
43
+ @register_model('phi')
44
+ class TinyLlavaPhiForCausalLM(PhiForCausalLM, LlavaMetaForCausalLM):
45
+ config_class = TinyLlavaPhiConfig
46
+
47
+ def __init__(self, config):
48
+ super(PhiForCausalLM, self).__init__(config)
49
+ self.model = TinyLlavaPhiModel(config)
50
+ # self.pretraining_tp = config.pretraining_tp
51
+ self.vocab_size = config.vocab_size
52
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
53
+
54
+ # Initialize weights and apply final processing
55
+ self.post_init()
56
+
57
+ def get_model(self):
58
+ return self.model
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: torch.LongTensor = None,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ position_ids: Optional[torch.LongTensor] = None,
65
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
66
+ inputs_embeds: Optional[torch.FloatTensor] = None,
67
+ labels: Optional[torch.LongTensor] = None,
68
+ use_cache: Optional[bool] = None,
69
+ output_attentions: Optional[bool] = None,
70
+ output_hidden_states: Optional[bool] = None,
71
+ images: Optional[torch.FloatTensor] = None,
72
+ # image_sizes: Optional[List[List[int]]] = None,
73
+ return_dict: Optional[bool] = None,
74
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
75
+
76
+ if inputs_embeds is None:
77
+ (
78
+ input_ids,
79
+ position_ids,
80
+ attention_mask,
81
+ past_key_values,
82
+ inputs_embeds,
83
+ labels
84
+ ) = self.prepare_inputs_labels_for_multimodal(
85
+ input_ids,
86
+ position_ids,
87
+ attention_mask,
88
+ past_key_values,
89
+ labels,
90
+ images,
91
+ # image_sizes
92
+ )
93
+
94
+ return super().forward(
95
+ input_ids=input_ids,
96
+ attention_mask=attention_mask,
97
+ position_ids=position_ids,
98
+ past_key_values=past_key_values,
99
+ inputs_embeds=inputs_embeds,
100
+ labels=labels,
101
+ use_cache=use_cache,
102
+ output_attentions=output_attentions,
103
+ output_hidden_states=output_hidden_states,
104
+ return_dict=return_dict
105
+ )
106
+
107
+ @torch.no_grad()
108
+ def generate(
109
+ self,
110
+ inputs: Optional[torch.Tensor] = None,
111
+ images: Optional[torch.Tensor] = None,
112
+ # image_sizes: Optional[torch.Tensor] = None,
113
+ **kwargs,
114
+ ) -> Union[GenerateOutput, torch.LongTensor]:
115
+ position_ids = kwargs.pop("position_ids", None)
116
+ attention_mask = kwargs.pop("attention_mask", None)
117
+ if "inputs_embeds" in kwargs:
118
+ raise NotImplementedError("`inputs_embeds` is not supported")
119
+
120
+ if images is not None:
121
+ (
122
+ inputs,
123
+ position_ids,
124
+ attention_mask,
125
+ _,
126
+ inputs_embeds,
127
+ _
128
+ ) = self.prepare_inputs_labels_for_multimodal(
129
+ inputs,
130
+ position_ids,
131
+ attention_mask,
132
+ None,
133
+ None,
134
+ images,
135
+ # image_sizes=image_sizes
136
+ )
137
+ else:
138
+ inputs_embeds = self.get_model().embed_tokens(inputs)
139
+
140
+ return super().generate(
141
+ position_ids=position_ids,
142
+ attention_mask=attention_mask,
143
+ inputs_embeds=inputs_embeds,
144
+ **kwargs
145
+ )
146
+
147
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
148
+ inputs_embeds=None, **kwargs):
149
+ images = kwargs.pop("images", None)
150
+ image_sizes = kwargs.pop("image_sizes", None)
151
+ inputs = super().prepare_inputs_for_generation(
152
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
153
+ )
154
+ if images is not None:
155
+ inputs['images'] = images
156
+ if image_sizes is not None:
157
+ inputs['image_sizes'] = image_sizes
158
+ return inputs
159
+
160
+ @register_tokenizer('phi')
161
+ def get_tokenizer():
162
+ from transformers import AutoTokenizer
163
+ def post_init(tokenizer):
164
+ return tokenizer
165
+ return AutoTokenizer, post_init
166
+
167
+ AutoConfig.register("tiny_llava_phi", TinyLlavaPhiConfig)
168
+ AutoModelForCausalLM.register(TinyLlavaPhiConfig, TinyLlavaPhiForCausalLM)
llava_qwen.py ADDED
File without changes
llava_stablelm.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright 2024 Baichuan Zhou, Junlong Jia, Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
22
+ from transformers.generation.utils import GenerateOutput
23
+ from tinyllava.model.language_model.stablelm.configuration_stablelm_epoch import StableLMEpochConfig
24
+ from tinyllava.model.language_model.stablelm.modeling_stablelm_epoch import StableLMEpochModel, StableLMEpochForCausalLM
25
+
26
+ from transformers.modeling_outputs import CausalLMOutputWithPast
27
+
28
+ from tinyllava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+ from tinyllava.model.model_factory import *
30
+
31
+ import torch.distributed as dist
32
+
33
+
34
+ class TinyLlavaStablelmConfig(StableLMEpochConfig):
35
+ model_type = "tiny_llava_stablelm"
36
+
37
+
38
+ class TinyLlavaStablelmModel(LlavaMetaModel, StableLMEpochModel):
39
+ config_class = TinyLlavaStablelmConfig
40
+
41
+ def __init__(self, config: StableLMEpochConfig):
42
+ super(TinyLlavaStablelmModel, self).__init__(config)
43
+
44
+ @register_model('stablelm')
45
+ class TinyLlavaStablelmForCausalLM(StableLMEpochForCausalLM, LlavaMetaForCausalLM):
46
+ config_class = TinyLlavaStablelmConfig
47
+
48
+ def __init__(self, config):
49
+ super(StableLMEpochForCausalLM, self).__init__(config)
50
+ self.model = TinyLlavaStablelmModel(config)
51
+ self.vocab_size = config.vocab_size
52
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
53
+
54
+ # Initialize weights and apply final processing
55
+ self.post_init()
56
+
57
+ def get_model(self):
58
+ return self.model
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: torch.LongTensor = None,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ position_ids: Optional[torch.LongTensor] = None,
65
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
66
+ inputs_embeds: Optional[torch.FloatTensor] = None,
67
+ labels: Optional[torch.LongTensor] = None,
68
+ use_cache: Optional[bool] = None,
69
+ output_attentions: Optional[bool] = None,
70
+ output_hidden_states: Optional[bool] = None,
71
+ images: Optional[torch.FloatTensor] = None,
72
+ # image_sizes: Optional[List[List[int]]] = None,
73
+ return_dict: Optional[bool] = None,
74
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
75
+
76
+ if inputs_embeds is None:
77
+ (
78
+ input_ids,
79
+ position_ids,
80
+ attention_mask,
81
+ past_key_values,
82
+ inputs_embeds,
83
+ labels
84
+ ) = self.prepare_inputs_labels_for_multimodal(
85
+ input_ids,
86
+ position_ids,
87
+ attention_mask,
88
+ past_key_values,
89
+ labels,
90
+ images,
91
+ # image_sizes
92
+ )
93
+
94
+ return super().forward(
95
+ input_ids=input_ids,
96
+ attention_mask=attention_mask,
97
+ position_ids=position_ids,
98
+ past_key_values=past_key_values,
99
+ inputs_embeds=inputs_embeds,
100
+ labels=labels,
101
+ use_cache=use_cache,
102
+ output_attentions=output_attentions,
103
+ output_hidden_states=output_hidden_states,
104
+ return_dict=return_dict
105
+ )
106
+
107
+ @torch.no_grad()
108
+ def generate(
109
+ self,
110
+ inputs: Optional[torch.Tensor] = None,
111
+ images: Optional[torch.Tensor] = None,
112
+ # image_sizes: Optional[torch.Tensor] = None,
113
+ **kwargs,
114
+ ) -> Union[GenerateOutput, torch.LongTensor]:
115
+ position_ids = kwargs.pop("position_ids", None)
116
+ attention_mask = kwargs.pop("attention_mask", None)
117
+ if "inputs_embeds" in kwargs:
118
+ raise NotImplementedError("`inputs_embeds` is not supported")
119
+
120
+ if images is not None:
121
+ (
122
+ inputs,
123
+ position_ids,
124
+ attention_mask,
125
+ _,
126
+ inputs_embeds,
127
+ _
128
+ ) = self.prepare_inputs_labels_for_multimodal(
129
+ inputs,
130
+ position_ids,
131
+ attention_mask,
132
+ None,
133
+ None,
134
+ images,
135
+ # image_sizes=image_sizes
136
+ )
137
+ else:
138
+ inputs_embeds = self.get_model().embed_tokens(inputs)
139
+
140
+ return super().generate(
141
+ position_ids=position_ids,
142
+ attention_mask=attention_mask,
143
+ inputs_embeds=inputs_embeds,
144
+ **kwargs
145
+ )
146
+
147
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
148
+ inputs_embeds=None, **kwargs):
149
+ images = kwargs.pop("images", None)
150
+ image_sizes = kwargs.pop("image_sizes", None)
151
+ inputs = super().prepare_inputs_for_generation(
152
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
153
+ )
154
+ if images is not None:
155
+ inputs['images'] = images
156
+ if image_sizes is not None:
157
+ inputs['image_sizes'] = image_sizes
158
+ return inputs
159
+
160
+ '''
161
+ @register_tokenizer('stablelm')
162
+ def get_tokenizer():
163
+ from .stablelm.tokenization_arcade100k import Arcade100kTokenizer
164
+ def post_init(tokenizer):
165
+ tokenizer.unk_token = tokenizer.pad_token
166
+ return tokenizer
167
+ return Arcade100kTokenizer, post_init
168
+ '''
169
+
170
+ @register_tokenizer('stablelm')
171
+ def get_tokenizer():
172
+ from transformers import AutoTokenizer
173
+ def post_init(tokenizer):
174
+ return tokenizer
175
+ return AutoTokenizer, post_init
176
+
177
+ AutoConfig.register("tiny_llava_stablelm", TinyLlavaStablelmConfig)
178
+ AutoModelForCausalLM.register(TinyLlavaStablelmConfig, TinyLlavaStablelmForCausalLM)
llava_trainer.py ADDED
@@ -0,0 +1,274 @@
1
+ import os
2
+ import torch
3
+ from torch import nn
4
+
5
+ from torch.utils.data import Sampler
6
+
7
+ from transformers import Trainer
8
+ from transformers.trainer import (
9
+ is_sagemaker_mp_enabled,
10
+ get_parameter_names,
11
+ has_length,
12
+ ALL_LAYERNORM_LAYERS,
13
+ # ShardedDDPOption,
14
+ logger,
15
+ )
16
+ from typing import List, Optional
17
+
18
+
19
+ def maybe_zero_3(param, ignore_status=False, name=None):
20
+ from deepspeed import zero
21
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
22
+ if hasattr(param, "ds_id"):
23
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
24
+ if not ignore_status:
25
+ print(name, 'no ignore status')
26
+ with zero.GatheredParameters([param]):
27
+ param = param.data.detach().cpu().clone()
28
+ else:
29
+ param = param.detach().cpu().clone()
30
+ return param
31
+
32
+
33
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
34
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
35
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
36
+ return to_return
37
+
38
+
39
+ def split_to_even_chunks(indices, lengths, num_chunks):
40
+ """
41
+ Split a list of indices into `chunks` chunks of roughly equal lengths.
42
+ """
43
+
44
+ if len(indices) % num_chunks != 0:
45
+ return [indices[i::num_chunks] for i in range(num_chunks)]
46
+
47
+ num_indices_per_chunk = len(indices) // num_chunks
48
+
49
+ chunks = [[] for _ in range(num_chunks)]
50
+ chunks_lengths = [0 for _ in range(num_chunks)]
51
+ for index in indices:
52
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
53
+ chunks[shortest_chunk].append(index)
54
+ chunks_lengths[shortest_chunk] += lengths[index]
55
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
56
+ chunks_lengths[shortest_chunk] = float("inf")
57
+
58
+ return chunks
59
+
60
+
61
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
62
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
63
+ assert all(l != 0 for l in lengths), "Should not have zero length."
64
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
65
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
66
+
67
+ assert len(mm_indices) > 0, "Should have at least one multimodal sample."
68
+ assert len(lang_indices) > 0, "Should have at least one language sample."
69
+
70
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
71
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
72
+ megabatch_size = world_size * batch_size
73
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
74
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
75
+
76
+ last_mm = mm_megabatches[-1]
77
+ last_lang = lang_megabatches[-1]
78
+ additional_batch = last_mm + last_lang
79
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
80
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
81
+ megabatches = [megabatches[i] for i in megabatch_indices]
82
+
83
+ if len(additional_batch) >= megabatch_size:
84
+ megabatches = [additional_batch[:megabatch_size]] + megabatches
85
+ additional_batch = additional_batch[megabatch_size:]
86
+
87
+ if len(additional_batch) > 0:
88
+ megabatches.append(additional_batch)
89
+
90
+ return [i for megabatch in megabatches for i in megabatch]
91
+
92
+
93
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
94
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
95
+ indices = torch.randperm(len(lengths), generator=generator)
96
+ megabatch_size = world_size * batch_size
97
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
98
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
99
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
100
+
101
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
102
+
103
+
104
+ class LengthGroupedSampler(Sampler):
105
+ r"""
106
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
107
+ keeping a bit of randomness.
108
+ """
109
+
110
+ def __init__(
111
+ self,
112
+ batch_size: int,
113
+ world_size: int,
114
+ lengths: Optional[List[int]] = None,
115
+ generator=None,
116
+ group_by_modality: bool = False,
117
+ ):
118
+ if lengths is None:
119
+ raise ValueError("Lengths must be provided.")
120
+
121
+ self.batch_size = batch_size
122
+ self.world_size = world_size
123
+ self.lengths = lengths
124
+ self.generator = generator
125
+ self.group_by_modality = group_by_modality
126
+
127
+ def __len__(self):
128
+ return len(self.lengths)
129
+
130
+ def __iter__(self):
131
+ if self.group_by_modality:
132
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
133
+ else:
134
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
135
+ return iter(indices)
136
+
137
+
138
+ class LLaVATrainer(Trainer):
139
+
140
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
141
+ if self.train_dataset is None or not has_length(self.train_dataset):
142
+ return None
143
+
144
+ if self.args.group_by_modality_length:
145
+ lengths = self.train_dataset.modality_lengths
146
+ return LengthGroupedSampler(
147
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
148
+ self.args.train_batch_size,
149
+ world_size=self.args.world_size,
150
+ lengths=lengths,
151
+ group_by_modality=True,
152
+ )
153
+ else:
154
+ return super()._get_train_sampler()
155
+
156
+ def create_optimizer(self):
157
+ """
158
+ Setup the optimizer.
159
+
160
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
161
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
162
+ """
163
+ if is_sagemaker_mp_enabled():
164
+ return super().create_optimizer()
165
+ # if self.sharded_ddp == ShardedDDPOption.SIMPLE:
166
+ # return super().create_optimizer()
167
+
168
+ opt_model = self.model
169
+
170
+ if self.optimizer is None:
171
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
172
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
173
+ if self.args.mm_projector_lr is not None:
174
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
175
+ optimizer_grouped_parameters = [
176
+ {
177
+ "params": [
178
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
179
+ ],
180
+ "weight_decay": self.args.weight_decay,
181
+ "name": "decay_no_proj_parameters"
182
+ },
183
+ {
184
+ "params": [
185
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
186
+ ],
187
+ "weight_decay": 0.0,
188
+ "name": "no_decay_no_proj_parameters"
189
+ },
190
+ {
191
+ "params": [
192
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
193
+ ],
194
+ "weight_decay": self.args.weight_decay,
195
+ "lr": self.args.mm_projector_lr,
196
+ "name": "decay_proj_parameters"
197
+ },
198
+ {
199
+ "params": [
200
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
201
+ ],
202
+ "weight_decay": 0.0,
203
+ "lr": self.args.mm_projector_lr,
204
+ "name": "no_decay_proj_parameters"
205
+ },
206
+ ]
207
+ else:
208
+ optimizer_grouped_parameters = [
209
+ {
210
+ "params": [
211
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
212
+ ],
213
+ "weight_decay": self.args.weight_decay,
214
+ "name": "decay_parameters"
215
+ },
216
+ {
217
+ "params": [
218
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
219
+ ],
220
+ "weight_decay": 0.0,
221
+ "name": "no_decay_parameters"
222
+ },
223
+ ]
224
+
225
+ if getattr(self.args, "moe_enable", False):
226
+ from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
227
+ optimizer_grouped_parameters = split_params_into_different_moe_groups_for_optimizer(optimizer_grouped_parameters)
228
+ optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args)
229
+
230
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
231
+ if optimizer_cls.__name__ == "Adam8bit":
232
+ import bitsandbytes
233
+
234
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
235
+
236
+ skipped = 0
237
+ for module in opt_model.modules():
238
+ if isinstance(module, nn.Embedding):
239
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
240
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
241
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
242
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
243
+ logger.info(f"skipped: {skipped/2**20}M params")
244
+
245
+ return self.optimizer
246
+
247
+ def _save_checkpoint(self, model, trial, metrics=None):
248
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
249
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
250
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
251
+
252
+ run_dir = self._get_output_dir(trial=trial)
253
+ output_dir = os.path.join(run_dir, checkpoint_folder)
254
+
255
+ # Only save Adapter
256
+ keys_to_match = ['mm_projector', 'vision_resampler']
257
+ if getattr(self.args, "use_im_start_end", False):
258
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
259
+
260
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
261
+
262
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
263
+ self.model.config.save_pretrained(output_dir)
264
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
265
+ else:
266
+ super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
267
+
268
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
269
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
270
+ pass
271
+ else:
272
+ super(LLaVATrainer, self)._save(output_dir, state_dict)
273
+
274
+
mm_utils.py ADDED
@@ -0,0 +1,252 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from tinyllava.constants import IMAGE_TOKEN_INDEX
8
+ import math
9
+ import ast
10
+
11
+
12
+ ## added by llava-1.6
13
+ def select_best_resolution(original_size, possible_resolutions):
14
+ """
15
+ Selects the best resolution from a list of possible resolutions based on the original size.
16
+
17
+ Args:
18
+ original_size (tuple): The original size of the image in the format (width, height).
19
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
20
+
21
+ Returns:
22
+ tuple: The best fit resolution in the format (width, height).
23
+ """
24
+ original_width, original_height = original_size
25
+ best_fit = None
26
+ max_effective_resolution = 0
27
+ min_wasted_resolution = float('inf')
28
+
29
+ for width, height in possible_resolutions:
30
+ scale = min(width / original_width, height / original_height)
31
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
32
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
33
+ wasted_resolution = (width * height) - effective_resolution
34
+
35
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
36
+ max_effective_resolution = effective_resolution
37
+ min_wasted_resolution = wasted_resolution
38
+ best_fit = (width, height)
39
+
40
+ return best_fit
41
+
42
+
43
+ ## added by llava-1.6
44
+ def resize_and_pad_image(image, target_resolution):
45
+ """
46
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
47
+
48
+ Args:
49
+ image (PIL.Image.Image): The input image.
50
+ target_resolution (tuple): The target resolution (width, height) of the image.
51
+
52
+ Returns:
53
+ PIL.Image.Image: The resized and padded image.
54
+ """
55
+ original_width, original_height = image.size
56
+ target_width, target_height = target_resolution
57
+
58
+ scale_w = target_width / original_width
59
+ scale_h = target_height / original_height
60
+
61
+ if scale_w < scale_h:
62
+ new_width = target_width
63
+ new_height = min(math.ceil(original_height * scale_w), target_height)
64
+ else:
65
+ new_height = target_height
66
+ new_width = min(math.ceil(original_width * scale_h), target_width)
67
+
68
+ # Resize the image
69
+ resized_image = image.resize((new_width, new_height))
70
+
71
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
72
+ paste_x = (target_width - new_width) // 2
73
+ paste_y = (target_height - new_height) // 2
74
+ new_image.paste(resized_image, (paste_x, paste_y))
75
+
76
+ return new_image
77
+
78
+
79
+ ## added by llava-1.6
80
+ def divide_to_patches(image, patch_size):
81
+ """
82
+ Divides an image into patches of a specified size.
83
+
84
+ Args:
85
+ image (PIL.Image.Image): The input image.
86
+ patch_size (int): The size of each patch.
87
+
88
+ Returns:
89
+ list: A list of PIL.Image.Image objects representing the patches.
90
+ """
91
+ patches = []
92
+ width, height = image.size
93
+ for i in range(0, height, patch_size):
94
+ for j in range(0, width, patch_size):
95
+ box = (j, i, j + patch_size, i + patch_size)
96
+ patch = image.crop(box)
97
+ patches.append(patch)
98
+
99
+ return patches
100
+
101
+
102
+ ## added by llava-1.6
103
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
104
+ """
105
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
106
+
107
+ Args:
108
+ image_size (tuple): The size of the input image in the format (width, height).
109
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
110
+ patch_size (int): The size of each image patch.
111
+
112
+ Returns:
113
+ tuple: The shape of the image patch grid in the format (width, height).
114
+ """
115
+ if type(grid_pinpoints) is list:
116
+ possible_resolutions = grid_pinpoints
117
+ else:
118
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
119
+ width, height = select_best_resolution(image_size, possible_resolutions)
120
+ return width // patch_size, height // patch_size
121
+
122
+
123
+ ## added by llava-1.6
124
+ def process_anyres_image(image, processor, grid_pinpoints):
125
+ """
126
+ Process an image with variable resolutions.
127
+
128
+ Args:
129
+ image (PIL.Image.Image): The input image to be processed.
130
+ processor: The image processor object.
131
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
132
+
133
+ Returns:
134
+ torch.Tensor: A tensor containing the processed image patches.
135
+ """
136
+ if type(grid_pinpoints) is list:
137
+ possible_resolutions = grid_pinpoints
138
+ else:
139
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
140
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
141
+ image_padded = resize_and_pad_image(image, best_resolution)
142
+
143
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
144
+
145
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
146
+
147
+ image_patches = [image_original_resize] + patches
148
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
149
+ for image_patch in image_patches]
150
+ return torch.stack(image_patches, dim=0)
151
+
152
+
153
+ def load_image_from_base64(image):
154
+ return Image.open(BytesIO(base64.b64decode(image)))
155
+
156
+
157
+ def expand2square(pil_img, background_color):
158
+ width, height = pil_img.size
159
+ if width == height:
160
+ return pil_img
161
+ elif width > height:
162
+ result = Image.new(pil_img.mode, (width, width), background_color)
163
+ result.paste(pil_img, (0, (width - height) // 2))
164
+ return result
165
+ else:
166
+ result = Image.new(pil_img.mode, (height, height), background_color)
167
+ result.paste(pil_img, ((height - width) // 2, 0))
168
+ return result
169
+
170
+
171
+ def process_images(images, image_processor, model_cfg):
172
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
173
+ new_images = []
174
+ if image_aspect_ratio == 'pad':
175
+ for image in images:
176
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
177
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
178
+ new_images.append(image)
179
+ elif image_aspect_ratio == "anyres":
180
+ for image in images:
181
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
182
+ new_images.append(image)
183
+ else:
184
+ return image_processor(images, return_tensors='pt')['pixel_values']
185
+ if all(x.shape == new_images[0].shape for x in new_images):
186
+ new_images = torch.stack(new_images, dim=0)
187
+ return new_images
188
+
189
+
190
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
191
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
192
+
193
+ def insert_separator(X, sep):
194
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
195
+
196
+ input_ids = []
197
+ offset = 0
198
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
199
+ offset = 1
200
+ input_ids.append(prompt_chunks[0][0])
201
+
202
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
203
+ input_ids.extend(x[offset:])
204
+
205
+ if return_tensors is not None:
206
+ if return_tensors == 'pt':
207
+ return torch.tensor(input_ids, dtype=torch.long)
208
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
209
+ return input_ids
210
+
211
+
212
+ def get_model_name_from_path(model_path):
213
+ model_path = model_path.strip("/")
214
+ model_paths = model_path.split("/")
215
+ if model_paths[-1].startswith('checkpoint-'):
216
+ return model_paths[-2] + "_" + model_paths[-1]
217
+ else:
218
+ return model_paths[-1]
219
+
220
+
221
+ class KeywordsStoppingCriteria(StoppingCriteria):
222
+ def __init__(self, keywords, tokenizer, input_ids):
223
+ self.keywords = keywords
224
+ self.keyword_ids = []
225
+ self.max_keyword_len = 0
226
+ for keyword in keywords:
227
+ cur_keyword_ids = tokenizer(keyword).input_ids
228
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
229
+ cur_keyword_ids = cur_keyword_ids[1:]
230
+ if len(cur_keyword_ids) > self.max_keyword_len:
231
+ self.max_keyword_len = len(cur_keyword_ids)
232
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
233
+ self.tokenizer = tokenizer
234
+ self.start_len = input_ids.shape[1]
235
+
236
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
237
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
238
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
239
+ for keyword_id in self.keyword_ids:
240
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
241
+ return True
242
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
243
+ for keyword in self.keywords:
244
+ if keyword in outputs:
245
+ return True
246
+ return False
247
+
248
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
249
+ outputs = []
250
+ for i in range(output_ids.shape[0]):
251
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
252
+ return all(outputs)
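A minimal usage sketch (not part of the repo) of how the helpers above are typically wired together at inference time: `tokenizer`, `model`, `image_processor`, a PIL `image`, and a conversation template `conv` with a `sep` string are assumed to be loaded elsewhere (e.g. via `load_pretrained_model`), and `IMAGE_TOKEN_INDEX` comes from `tinyllava.constants`.

```python
# Minimal sketch: wiring tokenizer_image_token, process_images and
# KeywordsStoppingCriteria into a single generate() call.
prompt = conv.get_prompt()  # prompt text containing one "<image>" placeholder
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors='pt').unsqueeze(0).cuda()
image_tensor = process_images([image], image_processor, model.config)[0]
stopping = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
output_ids = model.generate(input_ids,
                            images=image_tensor.unsqueeze(0).half().cuda(),
                            do_sample=False,
                            max_new_tokens=128,
                            stopping_criteria=[stopping],
                            use_cache=True)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
```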
model.jsonl ADDED
@@ -0,0 +1,5 @@
1
+ {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
2
+ {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
3
+ {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
4
+ {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
5
+ {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
model_factory.py ADDED
@@ -0,0 +1,78 @@
1
+ import os
2
+ import importlib
3
+
4
+ MODEL_REGISTRY = {}
5
+ TOKENIZER_REGISTRY = {}
6
+
7
+
8
+ def ModelSelect(model_name_or_path):
9
+ model = None
10
+ for name in MODEL_REGISTRY.keys():
11
+ if name.lower() in model_name_or_path.lower():
12
+ model = MODEL_REGISTRY[name]
13
+ if model is None:
14
+ model = MODEL_REGISTRY['llama']
15
+ return model
16
+
17
+
18
+ def TokenizerSelect(model_name_or_path):
19
+ tokenizer_init = None
20
+ for name in TOKENIZER_REGISTRY.keys():
21
+ if name.lower() in model_name_or_path.lower():
22
+ tokenizer_init = TOKENIZER_REGISTRY[name]
23
+ if tokenizer_init is None:
24
+ tokenizer_init = TOKENIZER_REGISTRY['llama']
25
+ return tokenizer_init
26
+
27
+
28
+ def register_model(name):
29
+ def register_model_cls(cls):
30
+ if name in MODEL_REGISTRY:
31
+ return MODEL_REGISTRY[name]
32
+
33
+ MODEL_REGISTRY[name] = cls
34
+ # FIXME: Find a more elegant way to do this
35
+ if name == 'phi':
36
+ MODEL_REGISTRY['TinyLLaVA-3.1B'] = cls
37
+ elif name == 'stablelm':
38
+ MODEL_REGISTRY['TinyLLaVA-2.0B'] = cls
39
+ elif name == 'llama':
40
+ MODEL_REGISTRY['TinyLLaVA-1.5B'] = cls
41
+ return cls
42
+
43
+ return register_model_cls
44
+
45
+
46
+ def register_tokenizer(name):
47
+ def register_tokenizer_cls(cls):
48
+ if name in TOKENIZER_REGISTRY:
49
+ return TOKENIZER_REGISTRY[name]
50
+
51
+ TOKENIZER_REGISTRY[name] = cls
52
+ # FIXME: Find a more elegant way to do this
53
+ if name == 'phi':
54
+ TOKENIZER_REGISTRY['TinyLLaVA-3.1B'] = cls
55
+ elif name == 'stablelm':
56
+ TOKENIZER_REGISTRY['TinyLLaVA-2.0B'] = cls
57
+ elif name == 'llama':
58
+ TOKENIZER_REGISTRY['TinyLLaVA-1.5B'] = cls
59
+ return cls
60
+
61
+ return register_tokenizer_cls
62
+
63
+
64
+ def import_models(models_dir, namespace):
65
+ for file in os.listdir(models_dir):
66
+ path = os.path.join(models_dir, file)
67
+ if (
68
+ not file.startswith("_")
69
+ and not file.startswith(".")
70
+ and file.endswith(".py")
71
+ ):
72
+ model_name = file[: file.find(".py")]  # endswith(".py") is already guaranteed by the check above
73
+ importlib.import_module(namespace + "." + model_name)
74
+
75
+
76
+ # automatically import any Python files in the models/ directory
77
+ models_dir = os.path.join(os.path.dirname(__file__), 'language_model')
78
+ import_models(models_dir, "tinyllava.model.language_model")
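For reference, a hedged sketch of how the registries above are meant to be populated by modules under `language_model/`. The class/function names and the import path are illustrative only, not taken from the repo.

```python
from tinyllava.model.model_factory import register_model, register_tokenizer  # assumed path

@register_model('phi')                # register_model also aliases this to 'TinyLLaVA-3.1B'
class MyPhiForCausalLM:               # hypothetical; in practice a PreTrainedModel subclass
    pass

@register_tokenizer('phi')
def my_phi_tokenizer_init(model_name_or_path):  # hypothetical tokenizer factory
    pass

# After the module is imported by import_models(), ModelSelect("TinyLLaVA-3.1B")
# resolves to MyPhiForCausalLM; names that match nothing fall back to the 'llama' entry.
```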
model_qa.py ADDED
@@ -0,0 +1,85 @@
1
+ import argparse
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3
+ import torch
4
+ import os
5
+ import json
6
+ from tqdm import tqdm
7
+ import shortuuid
8
+
9
+ from tinyllava.conversation import default_conversation
10
+ from tinyllava.utils import disable_torch_init
11
+
12
+
13
+ # new stopping implementation
14
+ class KeywordsStoppingCriteria(StoppingCriteria):
15
+ def __init__(self, keywords, tokenizer, input_ids):
16
+ self.keywords = keywords
17
+ self.tokenizer = tokenizer
18
+ self.start_len = None
19
+ self.input_ids = input_ids
20
+
21
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
22
+ if self.start_len is None:
23
+ self.start_len = self.input_ids.shape[1]
24
+ else:
25
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
26
+ for keyword in self.keywords:
27
+ if keyword in outputs:
28
+ return True
29
+ return False
30
+
31
+
32
+ @torch.inference_mode()
33
+ def eval_model(model_name, questions_file, answers_file):
34
+ # Model
35
+ disable_torch_init()
36
+ model_name = os.path.expanduser(model_name)
37
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
38
+ model = AutoModelForCausalLM.from_pretrained(model_name,
39
+ torch_dtype=torch.float16).cuda()
40
+
41
+
42
+ ques_file = open(os.path.expanduser(questions_file), "r")
43
+ ans_file = open(os.path.expanduser(answers_file), "w")
44
+ for i, line in enumerate(tqdm(ques_file)):
45
+ idx = json.loads(line)["question_id"]
46
+ qs = json.loads(line)["text"]
47
+ cat = json.loads(line)["category"]
48
+ conv = default_conversation.copy()
49
+ conv.append_message(conv.roles[0], qs)
50
+ prompt = conv.get_prompt()
51
+ inputs = tokenizer([prompt])
52
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
53
+ stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
54
+ output_ids = model.generate(
55
+ input_ids,
56
+ do_sample=True,
57
+ use_cache=True,
58
+ temperature=0.7,
59
+ max_new_tokens=1024,
60
+ stopping_criteria=[stopping_criteria])
61
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
62
+ try:
63
+ index = outputs.index(conv.sep, len(prompt))
64
+ except ValueError:
65
+ outputs += conv.sep
66
+ index = outputs.index(conv.sep, len(prompt))
67
+
68
+ outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
69
+ ans_id = shortuuid.uuid()
70
+ ans_file.write(json.dumps({"question_id": idx,
71
+ "text": outputs,
72
+ "answer_id": ans_id,
73
+ "model_id": model_name,
74
+ "metadata": {}}) + "\n")
75
+ ans_file.flush()
76
+ ans_file.close()
77
+
78
+ if __name__ == "__main__":
79
+ parser = argparse.ArgumentParser()
80
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
81
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
82
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
83
+ args = parser.parse_args()
84
+
85
+ eval_model(args.model_name, args.question_file, args.answers_file)
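A hedged sketch of calling the entry point above from Python rather than through the CLI flags; the import path and checkpoint name are placeholders.

```python
from tinyllava.eval.model_qa import eval_model  # assumed location of this file

eval_model(model_name="./checkpoints/my-chat-model",  # hypothetical checkpoint path
           questions_file="tables/question.jsonl",    # argparse default above
           answers_file="answer.jsonl")               # written as one JSON answer per line
```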
model_vqa.py ADDED
@@ -0,0 +1,100 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from tinyllava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from tinyllava.conversation import conv_templates, SeparatorStyle
10
+ from tinyllava.model.builder import load_pretrained_model
11
+ from tinyllava.utils import disable_torch_init
12
+ from tinyllava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
+
14
+ from PIL import Image
15
+ import math
16
+
17
+
18
+ def split_list(lst, n):
19
+ """Split a list into n (roughly) equal-sized chunks"""
20
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
21
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
+
23
+
24
+ def get_chunk(lst, n, k):
25
+ chunks = split_list(lst, n)
26
+ return chunks[k]
27
+
28
+
29
+ def eval_model(args):
30
+ # Model
31
+ disable_torch_init()
32
+ model_path = os.path.expanduser(args.model_path)
33
+ model_name = get_model_name_from_path(model_path)
34
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35
+
36
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38
+ answers_file = os.path.expanduser(args.answers_file)
39
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40
+ ans_file = open(answers_file, "w")
41
+ for line in tqdm(questions):
42
+ idx = line["question_id"]
43
+ image_file = line["image"]
44
+ qs = line["text"]
45
+ cur_prompt = qs
46
+ if model.config.mm_use_im_start_end:
47
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
48
+ else:
49
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
50
+
51
+ conv = conv_templates[args.conv_mode].copy()
52
+ conv.append_message(conv.roles[0], qs)
53
+ conv.append_message(conv.roles[1], None)
54
+ prompt = conv.get_prompt()
55
+
56
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
57
+
58
+ image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
59
+ image_tensor = process_images([image], image_processor, model.config)[0]
60
+
61
+ with torch.inference_mode():
62
+ output_ids = model.generate(
63
+ input_ids,
64
+ images=image_tensor.unsqueeze(0).half().cuda(),
65
+ do_sample=True if args.temperature > 0 else False,
66
+ temperature=args.temperature,
67
+ top_p=args.top_p,
68
+ num_beams=args.num_beams,
69
+ # no_repeat_ngram_size=3,
70
+ max_new_tokens=1024,
71
+ use_cache=True)
72
+
73
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
74
+
75
+ ans_id = shortuuid.uuid()
76
+ ans_file.write(json.dumps({"question_id": idx,
77
+ "prompt": cur_prompt,
78
+ "text": outputs,
79
+ "answer_id": ans_id,
80
+ "model_id": model_name,
81
+ "metadata": {}}) + "\n")
82
+ ans_file.flush()
83
+ ans_file.close()
84
+
85
+ if __name__ == "__main__":
86
+ parser = argparse.ArgumentParser()
87
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
88
+ parser.add_argument("--model-base", type=str, default=None)
89
+ parser.add_argument("--image-folder", type=str, default="")
90
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
91
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
92
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
93
+ parser.add_argument("--num-chunks", type=int, default=1)
94
+ parser.add_argument("--chunk-idx", type=int, default=0)
95
+ parser.add_argument("--temperature", type=float, default=0.2)
96
+ parser.add_argument("--top_p", type=float, default=None)
97
+ parser.add_argument("--num_beams", type=int, default=1)
98
+ args = parser.parse_args()
99
+
100
+ eval_model(args)
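`--num-chunks`/`--chunk-idx` shard the question file so several processes (e.g. one per GPU) can each evaluate a slice. A self-contained sketch of the chunking helpers above:

```python
import math

def split_list(lst, n):
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

questions = list(range(10))          # stand-in for 10 parsed question dicts
print(split_list(questions, 3))      # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(split_list(questions, 3)[1])   # chunk_idx=1 -> [4, 5, 6, 7]
```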
model_vqa_loader.py ADDED
@@ -0,0 +1,149 @@
1
+ import argparse
2
+ import time
3
+
4
+ import torch
5
+ import os
6
+ import json
7
+ from tqdm import tqdm
8
+ import shortuuid
9
+
10
+ from tinyllava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
11
+ from tinyllava.conversation import conv_templates, SeparatorStyle
12
+ from tinyllava.model.builder import load_pretrained_model
13
+ from tinyllava.utils import disable_torch_init
14
+ from tinyllava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, KeywordsStoppingCriteria
15
+ from torch.utils.data import Dataset, DataLoader
16
+
17
+ from PIL import Image
18
+ import math
19
+
20
+
21
+ def split_list(lst, n):
22
+ """Split a list into n (roughly) equal-sized chunks"""
23
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
24
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
25
+
26
+
27
+ def get_chunk(lst, n, k):
28
+ chunks = split_list(lst, n)
29
+ return chunks[k]
30
+
31
+
32
+ # Custom dataset class
33
+ class CustomDataset(Dataset):
34
+ def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
35
+ self.questions = questions
36
+ self.image_folder = image_folder
37
+ self.tokenizer = tokenizer
38
+ self.image_processor = image_processor
39
+ self.model_config = model_config
40
+
41
+ def __getitem__(self, index):
42
+ line = self.questions[index]
43
+ image_file = line["image"]
44
+ qs = line["text"]
45
+ if self.model_config.mm_use_im_start_end:
46
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
47
+ else:
48
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
49
+
50
+ conv = conv_templates[args.conv_mode].copy()
51
+ conv.append_message(conv.roles[0], qs)
52
+ conv.append_message(conv.roles[1], None)
53
+ prompt = conv.get_prompt()
54
+ #print(prompt)
55
+ image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
56
+ image_tensor = process_images([image], self.image_processor, self.model_config)[0]
57
+
58
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
59
+
60
+ return input_ids, image_tensor, image.size
61
+
62
+ def __len__(self):
63
+ return len(self.questions)
64
+
65
+
66
+ def collate_fn(batch):
67
+ input_ids, image_tensors, image_sizes = zip(*batch)
68
+ input_ids = torch.stack(input_ids, dim=0)
69
+ image_tensors = torch.stack(image_tensors, dim=0)
70
+ return input_ids, image_tensors, image_sizes
71
+
72
+
73
+ # DataLoader
74
+ def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
75
+ assert batch_size == 1, "batch_size must be 1"
76
+ dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
77
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
78
+ return data_loader
79
+
80
+
81
+ def eval_model(args):
82
+ # Model
83
+ disable_torch_init()
84
+ model_path = os.path.expanduser(args.model_path)
85
+ model_name = get_model_name_from_path(model_path)
86
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
87
+
88
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
89
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
90
+ answers_file = os.path.expanduser(args.answers_file)
91
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
92
+ ans_file = open(answers_file, "w")
93
+
94
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
95
+ args.conv_mode = args.conv_mode + '_mmtag'
96
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
97
+
98
+ data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
99
+ # print("Tokenizer's eos token: ", tokenizer.eos_token)
100
+ for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)):
101
+ idx = line["question_id"]
102
+ cur_prompt = line["text"]
103
+ # keywords = [tokenizer.eos_token]
104
+ # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
105
+ input_ids = input_ids.to(device='cuda', non_blocking=True)
106
+ with torch.inference_mode():
107
+ output_ids = model.generate(
108
+ input_ids,
109
+ images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
110
+ pad_token_id=tokenizer.pad_token_id,
111
+ do_sample=True if args.temperature > 0 else False,
112
+ temperature=args.temperature,
113
+ top_p=args.top_p,
114
+ num_beams=args.num_beams,
115
+ max_new_tokens=args.max_new_tokens,
116
+ # stopping_criteria=[stopping_criteria],
117
+ use_cache=True)
118
+
119
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
120
+ # print("Printing outputs")
121
+ # print(outputs)
122
+ # time.sleep(5)
123
+ ans_id = shortuuid.uuid()
124
+ ans_file.write(json.dumps({"question_id": idx,
125
+ "prompt": cur_prompt,
126
+ "text": outputs,
127
+ "answer_id": ans_id,
128
+ "model_id": model_name,
129
+ "metadata": {}}) + "\n")
130
+ # ans_file.flush()
131
+ ans_file.close()
132
+
133
+ if __name__ == "__main__":
134
+ parser = argparse.ArgumentParser()
135
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
136
+ parser.add_argument("--model-base", type=str, default=None)
137
+ parser.add_argument("--image-folder", type=str, default="")
138
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
139
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
140
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
141
+ parser.add_argument("--num-chunks", type=int, default=1)
142
+ parser.add_argument("--chunk-idx", type=int, default=0)
143
+ parser.add_argument("--temperature", type=float, default=0.2)
144
+ parser.add_argument("--top_p", type=float, default=None)
145
+ parser.add_argument("--num_beams", type=int, default=1)
146
+ parser.add_argument("--max_new_tokens", type=int, default=128)
147
+ args = parser.parse_args()
148
+
149
+ eval_model(args)
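Compared with model_vqa.py, this variant prefetches preprocessing through a DataLoader while still generating one sample at a time (`batch_size` is asserted to be 1). Note that `CustomDataset.__getitem__` reads the module-level `args.conv_mode`, so the dataset only works when this file is run as a script (or when a module-level `args` is provided). A hedged sketch of what one batch looks like, assuming `questions`, `tokenizer`, `image_processor` and `model` are already loaded:

```python
# num_workers=0 keeps the sketch single-process.
loader = create_data_loader(questions, "images/", tokenizer, image_processor,
                            model.config, num_workers=0)
input_ids, image_tensors, image_sizes = next(iter(loader))
print(input_ids.shape)      # (1, prompt_len) -- one tokenized prompt
print(image_tensors.shape)  # (1, 3, H, W) for the default square preprocessing
print(image_sizes)          # ((orig_width, orig_height),)
```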
model_vqa_mmbench.py ADDED
@@ -0,0 +1,162 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+ import shortuuid
8
+
9
+ from tinyllava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10
+ from tinyllava.conversation import conv_templates, SeparatorStyle
11
+ from tinyllava.model.builder import load_pretrained_model
12
+ from tinyllava.utils import disable_torch_init
13
+ from tinyllava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
14
+
15
+ from PIL import Image
16
+ import math
17
+
18
+ all_options = ['A', 'B', 'C', 'D']
19
+
20
+
21
+ def split_list(lst, n):
22
+ """Split a list into n (roughly) equal-sized chunks"""
23
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
24
+ return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
25
+
26
+
27
+ def get_chunk(lst, n, k):
28
+ chunks = split_list(lst, n)
29
+ return chunks[k]
30
+
31
+
32
+ def is_none(value):
33
+ if value is None:
34
+ return True
35
+ if type(value) is float and math.isnan(value):
36
+ return True
37
+ if type(value) is str and value.lower() == 'nan':
38
+ return True
39
+ if type(value) is str and value.lower() == 'none':
40
+ return True
41
+ return False
42
+
43
+
44
+ def get_options(row, options):
45
+ parsed_options = []
46
+ for option in options:
47
+ option_value = row[option]
48
+ if is_none(option_value):
49
+ break
50
+ parsed_options.append(option_value)
51
+ return parsed_options
52
+
53
+
54
+ def eval_model(args):
55
+ # Model
56
+ disable_torch_init()
57
+ model_path = os.path.expanduser(args.model_path)
58
+ model_name = get_model_name_from_path(model_path)
59
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
60
+
61
+ questions = pd.read_table(os.path.expanduser(args.question_file))
62
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
63
+ answers_file = os.path.expanduser(args.answers_file)
64
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
65
+ ans_file = open(answers_file, "w")
66
+
67
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
68
+ args.conv_mode = args.conv_mode + '_mmtag'
69
+ print(
70
+ f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
71
+
72
+ for index, row in tqdm(questions.iterrows(), total=len(questions)):
73
+ options = get_options(row, all_options)
74
+ cur_option_char = all_options[:len(options)]
75
+
76
+ if args.all_rounds:
77
+ num_rounds = len(options)
78
+ else:
79
+ num_rounds = 1
80
+
81
+ for round_idx in range(num_rounds):
82
+ idx = row['index']
83
+ question = row['question']
84
+ hint = row['hint']
85
+ image = load_image_from_base64(row['image'])
86
+ if not is_none(hint):
87
+ question = hint + '\n' + question
88
+ for option_char, option in zip(all_options[:len(options)], options):
89
+ question = question + '\n' + option_char + '. ' + option
90
+ qs = cur_prompt = question
91
+ if model.config.mm_use_im_start_end:
92
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
93
+ else:
94
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
95
+
96
+ if args.single_pred_prompt:
97
+ if args.lang == 'cn':
98
+ qs = qs + '\n' + "请直接回答选项字母。"
99
+ else:
100
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
101
+
102
+ conv = conv_templates[args.conv_mode].copy()
103
+ conv.append_message(conv.roles[0], qs)
104
+ conv.append_message(conv.roles[1], None)
105
+ prompt = conv.get_prompt()
106
+
107
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
108
+ 0).cuda()
109
+
110
+ image_tensor = process_images([image], image_processor, model.config)[0]
111
+
112
+ with torch.inference_mode():
113
+ output_ids = model.generate(
114
+ input_ids,
115
+ images=image_tensor.unsqueeze(0).half().cuda(),
116
+ do_sample=True if args.temperature > 0 else False,
117
+ temperature=args.temperature,
118
+ top_p=args.top_p,
119
+ num_beams=args.num_beams,
120
+ # no_repeat_ngram_size=3,
121
+ max_new_tokens=1024,
122
+ use_cache=True)
123
+
124
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
125
+
126
+ ans_id = shortuuid.uuid()
127
+ ans_file.write(json.dumps({"question_id": idx,
128
+ "round_id": round_idx,
129
+ "prompt": cur_prompt,
130
+ "text": outputs,
131
+ "options": options,
132
+ "option_char": cur_option_char,
133
+ "answer_id": ans_id,
134
+ "model_id": model_name,
135
+ "metadata": {}}) + "\n")
136
+ ans_file.flush()
137
+
138
+ # rotate options
139
+ options = options[1:] + options[:1]
140
+ cur_option_char = cur_option_char[1:] + cur_option_char[:1]
141
+ ans_file.close()
142
+
143
+
144
+ if __name__ == "__main__":
145
+ parser = argparse.ArgumentParser()
146
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
147
+ parser.add_argument("--model-base", type=str, default=None)
148
+ parser.add_argument("--image-folder", type=str, default="")
149
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
150
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
151
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
152
+ parser.add_argument("--num-chunks", type=int, default=1)
153
+ parser.add_argument("--chunk-idx", type=int, default=0)
154
+ parser.add_argument("--temperature", type=float, default=0.2)
155
+ parser.add_argument("--top_p", type=float, default=None)
156
+ parser.add_argument("--num_beams", type=int, default=1)
157
+ parser.add_argument("--all-rounds", action="store_true")
158
+ parser.add_argument("--single-pred-prompt", action="store_true")
159
+ parser.add_argument("--lang", type=str, default="en")
160
+ args = parser.parse_args()
161
+
162
+ eval_model(args)
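With `--all-rounds`, the option list is rotated once per round so the same answer is shown under a different letter each time (MMBench-style circular evaluation). A standalone sketch of the rotation at the bottom of the loop:

```python
all_options = ['A', 'B', 'C']
options = ['cat', 'dog', 'bird']     # hypothetical option texts
for round_idx in range(len(options)):
    print(round_idx, [f"{c}. {o}" for c, o in zip(all_options, options)])
    options = options[1:] + options[:1]   # same rotation as in eval_model()
# 0 ['A. cat', 'B. dog', 'C. bird']
# 1 ['A. dog', 'B. bird', 'C. cat']
# 2 ['A. bird', 'B. cat', 'C. dog']
```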
model_vqa_science.py ADDED
@@ -0,0 +1,122 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from tinyllava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from tinyllava.conversation import conv_templates, SeparatorStyle
10
+ from tinyllava.model.builder import load_pretrained_model
11
+ from tinyllava.utils import disable_torch_init
12
+ from tinyllava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
+
14
+ from PIL import Image
15
+ import math
16
+
17
+
18
+ def split_list(lst, n):
19
+ """Split a list into n (roughly) equal-sized chunks"""
20
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
21
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
+
23
+
24
+ def get_chunk(lst, n, k):
25
+ chunks = split_list(lst, n)
26
+ return chunks[k]
27
+
28
+
29
+ def eval_model(args):
30
+ # Model
31
+ disable_torch_init()
32
+ model_path = os.path.expanduser(args.model_path)
33
+ model_name = get_model_name_from_path(model_path)
34
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35
+
36
+ questions = json.load(open(os.path.expanduser(args.question_file), "r"))
37
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38
+ answers_file = os.path.expanduser(args.answers_file)
39
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40
+ ans_file = open(answers_file, "w")
41
+ for i, line in enumerate(tqdm(questions)):
42
+ idx = line["id"]
43
+ question = line['conversations'][0]
44
+ qs = question['value'].replace('<image>', '').strip()
45
+ cur_prompt = qs
46
+
47
+ if 'image' in line:
48
+ image_file = line["image"]
49
+ image = Image.open(os.path.join(args.image_folder, image_file))
50
+ image_tensor = process_images([image], image_processor, model.config)[0]
51
+ images = image_tensor.unsqueeze(0).half().cuda()
52
+
53
+ if getattr(model.config, 'mm_use_im_start_end', False):
54
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
55
+ else:
56
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
57
+ cur_prompt = '<image>' + '\n' + cur_prompt
58
+ else:
59
+ images = None
60
+
61
+
62
+ if args.single_pred_prompt:
63
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
64
+ cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
65
+
66
+ conv = conv_templates[args.conv_mode].copy()
67
+ conv.append_message(conv.roles[0], qs)
68
+ conv.append_message(conv.roles[1], None)
69
+ prompt = conv.get_prompt()
70
+
71
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
72
+
73
+ with torch.inference_mode():
74
+ output_ids = model.generate(
75
+ input_ids,
76
+ images=images,
77
+ do_sample=True if args.temperature > 0 else False,
78
+ temperature=args.temperature,
79
+ max_new_tokens=1024,
80
+ use_cache=True,
81
+
82
+ )
83
+
84
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
85
+ ans_id = shortuuid.uuid()
86
+ ans_file.write(json.dumps({"question_id": idx,
87
+ "prompt": cur_prompt,
88
+ "text": outputs,
89
+ "answer_id": ans_id,
90
+ "model_id": model_name,
91
+ "metadata": {}}) + "\n")
92
+ ans_file.flush()
93
+
94
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
95
+ ans_id = shortuuid.uuid()
96
+ ans_file.write(json.dumps({"question_id": idx,
97
+ "prompt": cur_prompt,
98
+ "text": outputs,
99
+ "answer_id": ans_id,
100
+ "model_id": model_name,
101
+ "metadata": {}}) + "\n")
102
+ ans_file.flush()
103
+ ans_file.close()
104
+
105
+ if __name__ == "__main__":
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
108
+ parser.add_argument("--model-base", type=str, default=None)
109
+ parser.add_argument("--image-folder", type=str, default="")
110
+ parser.add_argument("--question-file", type=str, default="tables/question.json")
111
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
112
+ parser.add_argument("--conv-mode", type=str, default="llava_v0")
113
+ parser.add_argument("--num-chunks", type=int, default=1)
114
+ parser.add_argument("--chunk-idx", type=int, default=0)
115
+ parser.add_argument("--temperature", type=float, default=0.2)
116
+ parser.add_argument("--answer-prompter", action="store_true")
117
+ parser.add_argument("--single-pred-prompt", action="store_true")
118
+ args = parser.parse_args()
119
+
120
+ eval_model(args)
121
+
122
+
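The `--question-file` here is a JSON list (not JSONL). A hedged illustration, not copied from ScienceQA itself, of the minimal fields the loop in eval_model() reads from each entry:

```python
entry = {
    "id": "4123",                         # written back as "question_id"
    "image": "4123/image.png",            # optional; text-only items omit it
    "conversations": [
        {"value": "<image>\nWhich option best describes the object?\nA. hard\nB. soft"}
    ],
}
qs = entry["conversations"][0]["value"].replace("<image>", "").strip()
```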
model_worker.py ADDED
@@ -0,0 +1,285 @@
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import threading
9
+ import uuid
10
+
11
+ from fastapi import FastAPI, Request, BackgroundTasks
12
+ from fastapi.responses import StreamingResponse
13
+ import requests
14
+ import torch
15
+ import uvicorn
16
+ from functools import partial
17
+
18
+ from tinyllava.constants import WORKER_HEART_BEAT_INTERVAL
19
+ from tinyllava.utils import (build_logger, server_error_msg,
20
+ pretty_print_semaphore)
21
+ from tinyllava.model.builder import load_pretrained_model
22
+ from tinyllava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
23
+ from tinyllava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+ from transformers import TextIteratorStreamer
25
+ from threading import Thread
26
+
27
+
28
+ GB = 1 << 30
29
+
30
+ worker_id = str(uuid.uuid4())[:6]
31
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
32
+ global_counter = 0
33
+
34
+ model_semaphore = None
35
+
36
+
37
+ def heart_beat_worker(controller):
38
+
39
+ while True:
40
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
41
+ controller.send_heart_beat()
42
+
43
+
44
+ class ModelWorker:
45
+ def __init__(self, controller_addr, worker_addr,
46
+ worker_id, no_register,
47
+ model_path, model_base, model_name,
48
+ load_8bit, load_4bit, device):
49
+ self.controller_addr = controller_addr
50
+ self.worker_addr = worker_addr
51
+ self.worker_id = worker_id
52
+ if model_path.endswith("/"):
53
+ model_path = model_path[:-1]
54
+ if model_name is None:
55
+ model_paths = model_path.split("/")
56
+ if model_paths[-1].startswith('checkpoint-'):
57
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
58
+ else:
59
+ self.model_name = model_paths[-1]
60
+ else:
61
+ self.model_name = model_name
62
+
63
+ self.device = device
64
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
65
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
66
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
67
+ self.is_multimodal = 'tinyllava' in self.model_name.lower()
68
+
69
+ if not no_register:
70
+ self.register_to_controller()
71
+ self.heart_beat_thread = threading.Thread(
72
+ target=heart_beat_worker, args=(self,))
73
+ self.heart_beat_thread.start()
74
+
75
+ def register_to_controller(self):
76
+ logger.info("Register to controller")
77
+
78
+ url = self.controller_addr + "/register_worker"
79
+ data = {
80
+ "worker_name": self.worker_addr,
81
+ "check_heart_beat": True,
82
+ "worker_status": self.get_status()
83
+ }
84
+ r = requests.post(url, json=data)
85
+ assert r.status_code == 200
86
+
87
+ def send_heart_beat(self):
88
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
89
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
90
+ f"global_counter: {global_counter}")
91
+
92
+ url = self.controller_addr + "/receive_heart_beat"
93
+
94
+ while True:
95
+ try:
96
+ ret = requests.post(url, json={
97
+ "worker_name": self.worker_addr,
98
+ "queue_length": self.get_queue_length()}, timeout=5)
99
+ exist = ret.json()["exist"]
100
+ break
101
+ except requests.exceptions.RequestException as e:
102
+ logger.error(f"heart beat error: {e}")
103
+ time.sleep(5)
104
+
105
+ if not exist:
106
+ self.register_to_controller()
107
+
108
+ def get_queue_length(self):
109
+ if model_semaphore is None:
110
+ return 0
111
+ else:
112
+ return args.limit_model_concurrency - model_semaphore._value + (len(
113
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
114
+
115
+ def get_status(self):
116
+ return {
117
+ "model_names": [self.model_name],
118
+ "speed": 1,
119
+ "queue_length": self.get_queue_length(),
120
+ }
121
+
122
+ @torch.inference_mode()
123
+ def generate_stream(self, params):
124
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
125
+
126
+ prompt = params["prompt"]
127
+ ori_prompt = prompt
128
+ images = params.get("images", None)
129
+ num_image_tokens = 0
130
+ if images is not None and len(images) > 0 and self.is_multimodal:
131
+ if len(images) > 0:
132
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
133
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
134
+
135
+ images = [load_image_from_base64(image) for image in images]
136
+ images = process_images(images, image_processor, model.config)
137
+
138
+ if type(images) is list:
139
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
140
+ else:
141
+ images = images.to(self.model.device, dtype=torch.float16)
142
+
143
+ replace_token = DEFAULT_IMAGE_TOKEN
144
+ if getattr(self.model.config, 'mm_use_im_start_end', False):
145
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
146
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
147
+
148
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
149
+ else:
150
+ images = None
151
+ image_args = {"images": images}
152
+ else:
153
+ images = None
154
+ image_args = {}
155
+
156
+ temperature = float(params.get("temperature", 1.0))
157
+ top_p = float(params.get("top_p", 1.0))
158
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
159
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
160
+ stop_str = params.get("stop", None)
161
+ do_sample = True if temperature > 0.001 else False
162
+
163
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
164
+ keywords = [stop_str]
165
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
166
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
167
+
168
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
169
+
170
+ if max_new_tokens < 1:
171
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
172
+ return
173
+
174
+ thread = Thread(target=model.generate, kwargs=dict(
175
+ inputs=input_ids,
176
+ do_sample=do_sample,
177
+ temperature=temperature,
178
+ top_p=top_p,
179
+ max_new_tokens=max_new_tokens,
180
+ streamer=streamer,
181
+ stopping_criteria=[stopping_criteria],
182
+ use_cache=True,
183
+ **image_args
184
+ ))
185
+ thread.start()
186
+
187
+ generated_text = ori_prompt
188
+ for new_text in streamer:
189
+ generated_text += new_text
190
+ if generated_text.endswith(stop_str):
191
+ generated_text = generated_text[:-len(stop_str)]
192
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
193
+
194
+ def generate_stream_gate(self, params):
195
+ try:
196
+ for x in self.generate_stream(params):
197
+ yield x
198
+ except ValueError as e:
199
+ print("Caught ValueError:", e)
200
+ ret = {
201
+ "text": server_error_msg,
202
+ "error_code": 1,
203
+ }
204
+ yield json.dumps(ret).encode() + b"\0"
205
+ except torch.cuda.CudaError as e:
206
+ print("Caught torch.cuda.CudaError:", e)
207
+ ret = {
208
+ "text": server_error_msg,
209
+ "error_code": 1,
210
+ }
211
+ yield json.dumps(ret).encode() + b"\0"
212
+ except Exception as e:
213
+ print("Caught Unknown Error", e)
214
+ ret = {
215
+ "text": server_error_msg,
216
+ "error_code": 1,
217
+ }
218
+ yield json.dumps(ret).encode() + b"\0"
219
+
220
+
221
+ app = FastAPI()
222
+
223
+
224
+ def release_model_semaphore(fn=None):
225
+ model_semaphore.release()
226
+ if fn is not None:
227
+ fn()
228
+
229
+
230
+ @app.post("/worker_generate_stream")
231
+ async def generate_stream(request: Request):
232
+ global model_semaphore, global_counter
233
+ global_counter += 1
234
+ params = await request.json()
235
+
236
+ if model_semaphore is None:
237
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
238
+ await model_semaphore.acquire()
239
+ worker.send_heart_beat()
240
+ generator = worker.generate_stream_gate(params)
241
+ background_tasks = BackgroundTasks()
242
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
243
+ return StreamingResponse(generator, background=background_tasks)
244
+
245
+
246
+ @app.post("/worker_get_status")
247
+ async def get_status(request: Request):
248
+ return worker.get_status()
249
+
250
+
251
+ if __name__ == "__main__":
252
+ parser = argparse.ArgumentParser()
253
+ parser.add_argument("--host", type=str, default="localhost")
254
+ parser.add_argument("--port", type=int, default=21002)
255
+ parser.add_argument("--worker-address", type=str,
256
+ default="http://localhost:21002")
257
+ parser.add_argument("--controller-address", type=str,
258
+ default="http://localhost:21001")
259
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
260
+ parser.add_argument("--model-base", type=str, default=None)
261
+ parser.add_argument("--model-name", type=str)
262
+ parser.add_argument("--device", type=str, default="cuda")
263
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `tinyllava` is included in the model path.")
264
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
265
+ parser.add_argument("--stream-interval", type=int, default=1)
266
+ parser.add_argument("--no-register", action="store_true")
267
+ parser.add_argument("--load-8bit", action="store_true")
268
+ parser.add_argument("--load-4bit", action="store_true")
269
+ args = parser.parse_args()
270
+ logger.info(f"args: {args}")
271
+
272
+ if args.multi_modal:
273
+ logger.warning("Multimodal mode is automatically detected with model name, please make sure `tinyllava` is included in the model path.")
274
+
275
+ worker = ModelWorker(args.controller_address,
276
+ args.worker_address,
277
+ worker_id,
278
+ args.no_register,
279
+ args.model_path,
280
+ args.model_base,
281
+ args.model_name,
282
+ args.load_8bit,
283
+ args.load_4bit,
284
+ args.device)
285
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
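A hedged client sketch for the `/worker_generate_stream` endpoint above. It assumes a worker running on `localhost:21002`; the prompt and stop strings are placeholders that must match the served model's conversation format. Each streamed chunk is a null-byte-terminated JSON object carrying the full text generated so far.

```python
import json
import requests

resp = requests.post("http://localhost:21002/worker_generate_stream",
                     json={"prompt": "USER: Hi, who are you? ASSISTANT:",  # placeholder template
                           "images": [],
                           "temperature": 0.2,
                           "top_p": 1.0,
                           "max_new_tokens": 64,
                           "stop": "</s>"},                               # placeholder stop string
                     stream=True)
for chunk in resp.iter_lines(delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])  # prompt + everything generated so far
```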
modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+
19
+
20
+ @dataclass
21
+ class AttentionMaskConverter:
22
+ """
23
+ A utility attention mask class that allows one to:
24
+ - Create a causal 4d mask
25
+ - Create a causal 4d mask with slided window
26
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
27
+ key_value_length) that can be multiplied with attention scores
28
+
29
+ Examples:
30
+
31
+ ```python
32
+ >>> import torch
33
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
+
35
+ >>> converter = AttentionMaskConverter(True)
36
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
37
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
38
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
39
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
40
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
41
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
42
+ ```
43
+
44
+ Parameters:
45
+ is_causal (`bool`):
46
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
47
+
48
+ sliding_window (`int`, *optional*):
49
+ Optionally, sliding-window masks can be created if `sliding_window` is set to a positive integer.
50
+ """
51
+
52
+ is_causal: bool
53
+ sliding_window: int
54
+
55
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
56
+ self.is_causal = is_causal
57
+ self.sliding_window = sliding_window
58
+
59
+ if self.sliding_window is not None and self.sliding_window <= 0:
60
+ raise ValueError(
61
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
62
+ )
63
+
64
+ def to_causal_4d(
65
+ self,
66
+ batch_size: int,
67
+ query_length: int,
68
+ key_value_length: int,
69
+ dtype: torch.dtype,
70
+ device: Union[torch.device, "str"] = "cpu",
71
+ ) -> Optional[torch.Tensor]:
72
+ """
73
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
74
+ bias to upper right hand triangular matrix (causal mask).
75
+ """
76
+ if not self.is_causal:
77
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
78
+
79
+ # If shape is not cached, create a new causal mask and cache it
80
+ input_shape = (batch_size, query_length)
81
+ past_key_values_length = key_value_length - query_length
82
+
83
+ # create causal mask
84
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
85
+ causal_4d_mask = None
86
+ if input_shape[-1] > 1 or self.sliding_window is not None:
87
+ causal_4d_mask = self._make_causal_mask(
88
+ input_shape,
89
+ dtype,
90
+ device=device,
91
+ past_key_values_length=past_key_values_length,
92
+ sliding_window=self.sliding_window,
93
+ )
94
+
95
+ return causal_4d_mask
96
+
97
+ def to_4d(
98
+ self,
99
+ attention_mask_2d: torch.Tensor,
100
+ query_length: int,
101
+ dtype: torch.dtype,
102
+ key_value_length: Optional[int] = None,
103
+ ) -> torch.Tensor:
104
+ """
105
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
106
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
107
+ causal, a causal mask will be added.
108
+ """
109
+ input_shape = (attention_mask_2d.shape[0], query_length)
110
+
111
+ # create causal mask
112
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
113
+ causal_4d_mask = None
114
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
115
+ if key_value_length is None:
116
+ raise ValueError(
117
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
118
+ )
119
+
120
+ past_key_values_length = key_value_length - query_length
121
+ causal_4d_mask = self._make_causal_mask(
122
+ input_shape,
123
+ dtype,
124
+ device=attention_mask_2d.device,
125
+ past_key_values_length=past_key_values_length,
126
+ sliding_window=self.sliding_window,
127
+ )
128
+ elif self.sliding_window is not None:
129
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
130
+
131
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
132
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
133
+ attention_mask_2d.device
134
+ )
135
+ if causal_4d_mask is not None:
136
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
137
+
138
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
139
+ expanded_4d_mask = expanded_attn_mask
140
+
141
+ return expanded_4d_mask
142
+
143
+ @staticmethod
144
+ def _make_causal_mask(
145
+ input_ids_shape: torch.Size,
146
+ dtype: torch.dtype,
147
+ device: torch.device,
148
+ past_key_values_length: int = 0,
149
+ sliding_window: Optional[int] = None,
150
+ ):
151
+ """
152
+ Make causal mask used for bi-directional self-attention.
153
+ """
154
+ bsz, tgt_len = input_ids_shape
155
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
156
+ mask_cond = torch.arange(mask.size(-1), device=device)
157
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
158
+
159
+ mask = mask.to(dtype)
160
+
161
+ if past_key_values_length > 0:
162
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
163
+
164
+ # add lower triangular sliding window mask if necessary
165
+ if sliding_window is not None:
166
+ diagonal = past_key_values_length - sliding_window + 1
167
+
168
+ context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
169
+ mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
170
+
171
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
172
+
173
+ @staticmethod
174
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
175
+ """
176
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
177
+ """
178
+ bsz, src_len = mask.size()
179
+ tgt_len = tgt_len if tgt_len is not None else src_len
180
+
181
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
182
+
183
+ inverted_mask = 1.0 - expanded_mask
184
+
185
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
186
+
187
+ @staticmethod
188
+ def _unmask_unattended(
189
+ expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
190
+ ):
191
+ # fmt: off
192
+ """
193
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
194
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
195
+ Details: https://github.com/pytorch/pytorch/issues/110213
196
+
197
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
198
+ `attention_mask` is [bsz, src_seq_len].
199
+
200
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
201
+
202
+ For example, if `attention_mask` is
203
+ ```
204
+ [[0, 0, 1],
205
+ [1, 1, 1],
206
+ [0, 1, 1]]
207
+ ```
208
+ and `expanded_mask` is (e.g. here left-padding case)
209
+ ```
210
+ [[[[0, 0, 0],
211
+ [0, 0, 0],
212
+ [0, 0, 1]]],
213
+ [[[1, 0, 0],
214
+ [1, 1, 0],
215
+ [1, 1, 1]]],
216
+ [[[0, 0, 0],
217
+ [0, 1, 0],
218
+ [0, 1, 1]]]]
219
+ ```
220
+ then the modified `expanded_mask` will be
221
+ ```
222
+ [[[[1, 1, 1], <-- modified
223
+ [1, 1, 1], <-- modified
224
+ [0, 0, 1]]],
225
+ [[[1, 0, 0],
226
+ [1, 1, 0],
227
+ [1, 1, 1]]],
228
+ [[[1, 1, 1], <-- modified
229
+ [0, 1, 0],
230
+ [0, 1, 1]]]]
231
+ ```
232
+ """
233
+ # fmt: on
234
+
235
+ # Get the index of the first non-zero value for every sample in the batch.
236
+ # In the above example, indices = [[2], [0], [1]]]
237
+ tmp = torch.arange(attention_mask.shape[1], 0, -1)
238
+ indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
239
+
240
+ # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
241
+ # expanded mask will be completely unattended.
242
+ left_masked_rows = torch.where(indices > 0)[0]
243
+
244
+ if left_masked_rows.shape[0] == 0:
245
+ return expanded_mask
246
+ indices = indices[left_masked_rows]
247
+
248
+ max_len = torch.max(indices)
249
+ range_tensor = torch.arange(max_len).unsqueeze(0)
250
+ range_tensor = range_tensor.repeat(indices.size(0), 1)
251
+
252
+ # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
253
+ range_tensor[range_tensor >= indices] = 0
254
+
255
+ # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
256
+ if expanded_mask.dim() == 4:
257
+ num_masks = expanded_mask.shape[1]
258
+ if num_masks == 1:
259
+ # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
260
+ mask_slice = (left_masked_rows[:, None], 0, range_tensor)
261
+ else:
262
+ # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
263
+ mask_slice = (
264
+ left_masked_rows[:, None, None],
265
+ torch.arange(num_masks)[None, :, None],
266
+ range_tensor[:, None, :],
267
+ )
268
+ else:
269
+ # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
270
+ mask_slice = (left_masked_rows[:, None], range_tensor)
271
+
272
+ expanded_mask[mask_slice] = unmasked_value
273
+
274
+ return expanded_mask
275
+
276
+
277
+ def _prepare_4d_causal_attention_mask(
278
+ attention_mask: Optional[torch.Tensor],
279
+ input_shape: Union[torch.Size, Tuple, List],
280
+ inputs_embeds: torch.Tensor,
281
+ past_key_values_length: int,
282
+ sliding_window: Optional[int] = None,
283
+ ):
284
+ """
285
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
286
+ `(batch_size, key_value_length)`
287
+
288
+ Args:
289
+ attention_mask (`torch.Tensor` or `None`):
290
+ A 2D attention mask of shape `(batch_size, key_value_length)`
291
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
292
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
293
+ inputs_embeds (`torch.Tensor`):
294
+ The embedded inputs as a torch Tensor.
295
+ past_key_values_length (`int`):
296
+ The length of the key value cache.
297
+ sliding_window (`int`, *optional*):
298
+ If the model uses windowed attention, a sliding window should be passed.
299
+ """
300
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
301
+
302
+ key_value_length = input_shape[-1] + past_key_values_length
303
+
304
+ # 4d mask is passed through the layers
305
+ if attention_mask is not None and len(attention_mask.shape) == 2:
306
+ attention_mask = attn_mask_converter.to_4d(
307
+ attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
308
+ )
309
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
310
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
311
+ if tuple(attention_mask.shape) != expected_shape:
312
+ raise ValueError(
313
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
314
+ )
315
+ else:
316
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
317
+ inverted_mask = 1.0 - attention_mask
318
+ attention_mask = inverted_mask.masked_fill(
319
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
320
+ )
321
+ else:
322
+ attention_mask = attn_mask_converter.to_causal_4d(
323
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
324
+ )
325
+
326
+ return attention_mask
327
+
328
+
329
+ # Adapted from _prepare_4d_causal_attention_mask
330
+ def _prepare_4d_causal_attention_mask_for_sdpa(
331
+ attention_mask: Optional[torch.Tensor],
332
+ input_shape: Union[torch.Size, Tuple, List],
333
+ inputs_embeds: torch.Tensor,
334
+ past_key_values_length: int,
335
+ sliding_window: Optional[int] = None,
336
+ ):
337
+ """
338
+ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
339
+
340
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
341
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
342
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
343
+ """
344
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
345
+
346
+ key_value_length = input_shape[-1] + past_key_values_length
347
+ batch_size, query_length = input_shape
348
+
349
+ # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
350
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
351
+ # TODO: Fix this as well when using torchdynamo with fullgraph=True.
352
+ is_tracing = torch.jit.is_tracing()
353
+
354
+ if attention_mask is not None:
355
+ # 4d mask is passed through
356
+ if len(attention_mask.shape) == 4:
357
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
358
+ if tuple(attention_mask.shape) != expected_shape:
359
+ raise ValueError(
360
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
361
+ )
362
+ else:
363
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
364
+ inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
365
+ attention_mask = inverted_mask.masked_fill(
366
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
367
+ )
368
+ return attention_mask
369
+
370
+ elif torch.all(attention_mask == 1):
371
+ if is_tracing:
372
+ pass
373
+ elif query_length == 1:
374
+ # For query_length == 1, causal attention and bi-directional attention are the same.
375
+ attention_mask = None
376
+ elif key_value_length == query_length:
377
+ attention_mask = None
378
+ else:
379
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
380
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
381
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
382
+ pass
383
+ elif query_length > 1 and key_value_length != query_length:
384
+ # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
385
+ # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
386
+ attention_mask = True
387
+ elif is_tracing:
388
+ raise ValueError(
389
+ 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
390
+ )
391
+
392
+ if attention_mask is None:
393
+ expanded_4d_mask = None
394
+ elif attention_mask is True:
395
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
396
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
397
+ )
398
+ else:
399
+ expanded_4d_mask = attn_mask_converter.to_4d(
400
+ attention_mask,
401
+ input_shape[-1],
402
+ dtype=inputs_embeds.dtype,
403
+ key_value_length=key_value_length,
404
+ )
405
+
406
+ # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
407
+ # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
408
+ if query_length > 1:
409
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
410
+ expanded_4d_mask, attention_mask, unmasked_value=0.0
411
+ )
412
+
413
+ return expanded_4d_mask
414
+
415
+
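Continuing the sketch above (same import assumption), the `None` shortcut described in the docstring of `_prepare_4d_causal_attention_mask_for_sdpa` can be seen directly: with no padding and `key_value_length == query_length` the helper returns `None`, while a padded mask still produces a full 4D tensor:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

inputs_embeds = torch.randn(1, 4, 16)
no_padding = torch.ones(1, 4, dtype=torch.long)
padded = torch.tensor([[1, 1, 1, 0]])

print(_prepare_4d_causal_attention_mask_for_sdpa(no_padding, (1, 4), inputs_embeds, 0))    # None
print(_prepare_4d_causal_attention_mask_for_sdpa(padded, (1, 4), inputs_embeds, 0).shape)  # torch.Size([1, 1, 4, 4])
```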
416
+ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
417
+ """
418
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
419
+ `(batch_size, key_value_length)`
420
+
421
+ Args:
422
+ mask (`torch.Tensor` or `None`):
423
+ A 2D attention mask of shape `(batch_size, key_value_length)`
424
+ dtype (`torch.dtype`):
425
+ The torch dtype the created mask shall have.
426
+ tgt_len (`int`):
427
+ The target length or query length the created mask shall have.
428
+ """
429
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
430
+
431
+
432
+ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
433
+ """
434
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
435
+ `(batch_size, key_value_length)`
436
+
437
+ Args:
438
+ mask (`torch.Tensor` or `None`):
439
+ A 2D attention mask of shape `(batch_size, key_value_length)`
440
+ dtype (`torch.dtype`):
441
+ The torch dtype the created mask shall have.
442
+ tgt_len (`int`):
443
+ The target length or query length the created mask shall have.
444
+ """
445
+ batch_size, key_value_length = mask.shape
446
+ tgt_len = tgt_len if tgt_len is not None else key_value_length
447
+
448
+ # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
449
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
450
+ # TODO: Fix this as well when using torchdynamo with fullgraph=True.
451
+ is_tracing = torch.jit.is_tracing()
452
+
453
+ if torch.all(mask == 1):
454
+ if is_tracing:
455
+ pass
456
+ elif tgt_len == 1:
457
+ # For query_length == 1, causal attention and bi-directional attention are the same.
458
+ return None
459
+ elif key_value_length == tgt_len:
460
+ return None
461
+ else:
462
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
463
+ # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
464
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
465
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
466
+ else:
467
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
468
+
469
+
470
+ def _create_4d_causal_attention_mask(
471
+ input_shape: Union[torch.Size, Tuple, List],
472
+ dtype: torch.dtype,
473
+ device: torch.device,
474
+ past_key_values_length: int = 0,
475
+ sliding_window: Optional[int] = None,
476
+ ) -> Optional[torch.Tensor]:
477
+ """
478
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
479
+
480
+ Args:
481
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
482
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
483
+ dtype (`torch.dtype`):
484
+ The torch dtype the created mask shall have.
485
+ device (`torch.device`):
486
+ The torch device the created mask shall have.
487
+ sliding_window (`int`, *optional*):
488
+ If the model uses windowed attention, a sliding window should be passed.
489
+ """
490
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
491
+
492
+ key_value_length = past_key_values_length + input_shape[-1]
493
+ attention_mask = attn_mask_converter.to_causal_4d(
494
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
495
+ )
496
+
497
+ return attention_mask
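A quick illustration of `_create_4d_causal_attention_mask` under the same import assumption: with `past_key_values_length=2` and a query of length 4, the resulting mask covers a key/value length of 6:

```python
import torch
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask

mask = _create_4d_causal_attention_mask(
    (1, 4), dtype=torch.float32, device=torch.device("cpu"), past_key_values_length=2
)
print(mask.shape)  # torch.Size([1, 1, 4, 6])
# zero entries mark attendable positions; masked entries hold torch.finfo(torch.float32).min
```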
modeling_phi.py ADDED
@@ -0,0 +1,1344 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi model."""
17
+
18
+
19
+ import math
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPast,
33
+ CausalLMOutputWithPast,
34
+ SequenceClassifierOutputWithPast,
35
+ TokenClassifierOutput,
36
+ )
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import (
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ is_flash_attn_2_available,
43
+ is_flash_attn_greater_or_equal_2_10,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_phi import PhiConfig
48
+
49
+
50
+ try:
51
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
52
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
53
+ except ImportError:
54
+ pass
55
+
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CHECKPOINT_FOR_DOC = "microsoft/phi-2"
60
+ _CONFIG_FOR_DOC = "PhiConfig"
61
+
62
+ PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "microsoft/phi-2",
64
+ # See all Phi models at https://huggingface.co/models?filter=phi
65
+ ]
66
+
67
+
68
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
69
+ def _get_unpad_data(attention_mask):
70
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
71
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
72
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
73
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
74
+ return (
75
+ indices,
76
+ cu_seqlens,
77
+ max_seqlen_in_batch,
78
+ )
79
+
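As a worked example of what `_get_unpad_data` computes, the snippet below re-derives the same three quantities with plain torch calls for a toy padding mask (it intentionally does not import this module):

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int32)
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([2, 3])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 3, 4, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                          # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 2, 5])
```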
80
+
81
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
82
+ class PhiRotaryEmbedding(nn.Module):
83
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
84
+ super().__init__()
85
+
86
+ self.dim = dim
87
+ self.max_position_embeddings = max_position_embeddings
88
+ self.base = base
89
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
90
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
91
+
92
+ # Build here to make `torch.jit.trace` work.
93
+ self._set_cos_sin_cache(
94
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
95
+ )
96
+
97
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
98
+ self.max_seq_len_cached = seq_len
99
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
100
+
101
+ freqs = torch.outer(t, self.inv_freq)
102
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
103
+ emb = torch.cat((freqs, freqs), dim=-1)
104
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
105
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
106
+
107
+ def forward(self, x, seq_len=None):
108
+ # x: [bs, num_attention_heads, seq_len, head_size]
109
+ if seq_len > self.max_seq_len_cached:
110
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
111
+
112
+ return (
113
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
114
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
115
+ )
116
+
117
+
118
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
119
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
120
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
121
+
122
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
123
+ self.scaling_factor = scaling_factor
124
+ super().__init__(dim, max_position_embeddings, base, device)
125
+
126
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
127
+ self.max_seq_len_cached = seq_len
128
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
129
+ t = t / self.scaling_factor
130
+
131
+ freqs = torch.outer(t, self.inv_freq)
132
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
133
+ emb = torch.cat((freqs, freqs), dim=-1)
134
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
135
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
136
+
137
+
138
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
139
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
140
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
141
+
142
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
143
+ self.scaling_factor = scaling_factor
144
+ super().__init__(dim, max_position_embeddings, base, device)
145
+
146
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
147
+ self.max_seq_len_cached = seq_len
148
+
149
+ if seq_len > self.max_position_embeddings:
150
+ base = self.base * (
151
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
152
+ ) ** (self.dim / (self.dim - 2))
153
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
154
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
155
+
156
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
157
+
158
+ freqs = torch.outer(t, self.inv_freq)
159
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
160
+ emb = torch.cat((freqs, freqs), dim=-1)
161
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
162
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
163
+
164
+
165
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
166
+ def rotate_half(x):
167
+ """Rotates half the hidden dims of the input."""
168
+ x1 = x[..., : x.shape[-1] // 2]
169
+ x2 = x[..., x.shape[-1] // 2 :]
170
+ return torch.cat((-x2, x1), dim=-1)
171
+
172
+
173
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
174
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
175
+ """Applies Rotary Position Embedding to the query and key tensors.
176
+ Args:
177
+ q (`torch.Tensor`): The query tensor.
178
+ k (`torch.Tensor`): The key tensor.
179
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
180
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
181
+ position_ids (`torch.Tensor`):
182
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
183
+ used to pass offsetted position ids when working with a KV-cache.
184
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
185
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
186
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
187
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
188
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
189
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
190
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
191
+ Returns:
192
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
193
+ """
194
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
195
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
196
+ q_embed = (q * cos) + (rotate_half(q) * sin)
197
+ k_embed = (k * cos) + (rotate_half(k) * sin)
198
+ return q_embed, k_embed
199
+
200
+
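The following self-contained sketch shows the rotary update that `apply_rotary_pos_emb` performs: it rebuilds `cos`/`sin` the same way `PhiRotaryEmbedding._set_cos_sin_cache` does and copies `rotate_half` locally, so it runs without importing this module (sizes are illustrative):

```python
import torch

dim, seq_len = 8, 6
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                       # each (seq_len, dim), like cos_cached / sin_cached

def rotate_half(x):                                   # local copy of the helper above
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

q = torch.randn(1, 2, seq_len, dim)                   # (batch, heads, seq, rotary_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)     # (1, seq_len)
cos_sel = cos[position_ids].unsqueeze(1)              # (1, 1, seq_len, dim), broadcastable to q
sin_sel = sin[position_ids].unsqueeze(1)
q_embed = (q * cos_sel) + (rotate_half(q) * sin_sel)  # same formula as apply_rotary_pos_emb
print(q_embed.shape)                                  # torch.Size([1, 2, 6, 8])
```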
201
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
202
+ class PhiMLP(nn.Module):
203
+ def __init__(self, config):
204
+ super().__init__()
205
+ self.config = config
206
+ self.activation_fn = ACT2FN[config.hidden_act]
207
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
208
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
209
+
210
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
211
+ hidden_states = self.fc1(hidden_states)
212
+ hidden_states = self.activation_fn(hidden_states)
213
+ hidden_states = self.fc2(hidden_states)
214
+ return hidden_states
215
+
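`PhiMLP` above is a plain two-layer feed-forward block; a minimal equivalent computation with made-up sizes and the same `ACT2FN` lookup (the real activation key comes from `config.hidden_act`) looks like this:

```python
import torch
from torch import nn
from transformers.activations import ACT2FN

hidden_size, intermediate_size = 16, 64   # illustrative sizes, not the real Phi config values
fc1 = nn.Linear(hidden_size, intermediate_size)
fc2 = nn.Linear(intermediate_size, hidden_size)
act = ACT2FN["gelu_new"]                  # assumed activation key for this sketch

x = torch.randn(2, 5, hidden_size)
y = fc2(act(fc1(x)))                      # same computation as PhiMLP.forward
print(y.shape)                            # torch.Size([2, 5, 16])
```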
216
+
217
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
218
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
219
+ """
220
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
221
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
222
+ """
223
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
224
+ if n_rep == 1:
225
+ return hidden_states
226
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
227
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
228
+
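A quick check of the docstring's claim that `repeat_kv` is equivalent to `torch.repeat_interleave` along the head dimension (the expand/reshape below mirrors the function body):

```python
import torch

batch, num_kv_heads, seq_len, head_dim, n_rep = 2, 4, 5, 8, 3
x = torch.randn(batch, num_kv_heads, seq_len, head_dim)

expanded = x[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
out = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)  # what repeat_kv returns

assert torch.equal(out, torch.repeat_interleave(x, repeats=n_rep, dim=1))
print(out.shape)  # torch.Size([2, 12, 5, 8])
```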
229
+
230
+ class PhiAttention(nn.Module):
231
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
232
+
233
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
234
+ super().__init__()
235
+ self.config = config
236
+ self.layer_idx = layer_idx
237
+ if layer_idx is None:
238
+ logger.warning_once(
239
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
240
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
241
+ "when creating this class."
242
+ )
243
+
244
+ self.attention_dropout = config.attention_dropout
245
+ self.hidden_size = config.hidden_size
246
+ self.num_heads = config.num_attention_heads
247
+ self.head_dim = self.hidden_size // self.num_heads
248
+ self.num_key_value_heads = config.num_key_value_heads
249
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
250
+ self.max_position_embeddings = config.max_position_embeddings
251
+ self.rope_theta = config.rope_theta
252
+ self.partial_rotary_factor = config.partial_rotary_factor
253
+ self.is_causal = True
254
+
255
+ if (self.head_dim * self.num_heads) != self.hidden_size:
256
+ raise ValueError(
257
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
258
+ f" and `num_heads`: {self.num_heads})."
259
+ )
260
+
261
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
262
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
263
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
264
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
265
+
266
+ self.qk_layernorm = config.qk_layernorm
267
+ if self.qk_layernorm:
268
+ self.q_layernorm = nn.LayerNorm(
269
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
270
+ )
271
+ self.k_layernorm = nn.LayerNorm(
272
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
273
+ )
274
+
275
+ self._init_rope()
276
+
277
+ def _init_rope(self):
278
+ if self.config.rope_scaling is None:
279
+ self.rotary_emb = PhiRotaryEmbedding(
280
+ int(self.partial_rotary_factor * self.head_dim),
281
+ max_position_embeddings=self.max_position_embeddings,
282
+ base=self.rope_theta,
283
+ )
284
+ else:
285
+ scaling_type = self.config.rope_scaling["type"]
286
+ scaling_factor = self.config.rope_scaling["factor"]
287
+ if scaling_type == "linear":
288
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
289
+ int(self.partial_rotary_factor * self.head_dim),
290
+ max_position_embeddings=self.max_position_embeddings,
291
+ scaling_factor=scaling_factor,
292
+ base=self.rope_theta,
293
+ )
294
+ elif scaling_type == "dynamic":
295
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
296
+ int(self.partial_rotary_factor * self.head_dim),
297
+ max_position_embeddings=self.max_position_embeddings,
298
+ scaling_factor=scaling_factor,
299
+ base=self.rope_theta,
300
+ )
301
+ else:
302
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
303
+
304
+ # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
305
+ @torch.autocast("cpu", enabled=False)
306
+ @torch.autocast("cuda", enabled=False)
307
+ def forward(
308
+ self,
309
+ hidden_states: torch.Tensor,
310
+ attention_mask: Optional[torch.Tensor] = None,
311
+ position_ids: Optional[torch.LongTensor] = None,
312
+ past_key_value: Optional[Cache] = None,
313
+ output_attentions: bool = False,
314
+ use_cache: bool = False,
315
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
316
+ bsz, q_len, _ = hidden_states.size()
317
+
318
+ query_states = self.q_proj(hidden_states)
319
+ key_states = self.k_proj(hidden_states)
320
+ value_states = self.v_proj(hidden_states)
321
+
322
+ if self.qk_layernorm:
323
+ query_states = self.q_layernorm(query_states)
324
+ key_states = self.k_layernorm(key_states)
325
+
326
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
327
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
328
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
329
+
330
+ kv_seq_len = key_states.shape[-2]
331
+ if past_key_value is not None:
332
+ if self.layer_idx is None:
333
+ raise ValueError(
334
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
335
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
336
+ "with a layer index."
337
+ )
338
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
339
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
340
+
341
+ # Partial rotary embedding
342
+ query_rot, query_pass = (
343
+ query_states[..., : self.rotary_emb.dim],
344
+ query_states[..., self.rotary_emb.dim :],
345
+ )
346
+ key_rot, key_pass = (
347
+ key_states[..., : self.rotary_emb.dim],
348
+ key_states[..., self.rotary_emb.dim :],
349
+ )
350
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
351
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
352
+
353
+ # [batch_size, seq_length, num_heads, head_dim]
354
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
355
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
356
+
357
+ if past_key_value is not None:
358
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
359
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
360
+
361
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
362
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
363
+
364
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
365
+ attn_weights = torch.matmul(
366
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
367
+ ) / math.sqrt(self.head_dim)
368
+
369
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
370
+ raise ValueError(
371
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
372
+ f" {attn_weights.size()}"
373
+ )
374
+
375
+ if attention_mask is not None:
376
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
377
+ raise ValueError(
378
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
379
+ )
380
+ attn_weights = attn_weights + attention_mask
381
+
382
+ # upcast attention to fp32
383
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
384
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
385
+
386
+ attn_output = torch.matmul(attn_weights, value_states)
387
+
388
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
389
+ raise ValueError(
390
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
391
+ f" {attn_output.size()}"
392
+ )
393
+
394
+ attn_output = attn_output.transpose(1, 2).contiguous()
395
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
396
+
397
+ attn_output = self.dense(attn_output)
398
+
399
+ if not output_attentions:
400
+ attn_weights = None
401
+
402
+ return attn_output, attn_weights, past_key_value
403
+
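A stripped-down sketch of the eager score computation in `PhiAttention.forward` above, showing the fp32 upcast of the Q·K^T product before softmax and the additive causal mask (random tensors and illustrative sizes only):

```python
import math
import torch

bsz, num_heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float16)
key = torch.randn(bsz, num_heads, kv_len, head_dim, dtype=torch.float16)
value = torch.randn(bsz, num_heads, kv_len, head_dim, dtype=torch.float16)

# additive causal mask: 0 where attending is allowed, a large negative value above the diagonal
causal_mask = torch.triu(torch.full((q_len, kv_len), torch.finfo(torch.float32).min), diagonal=1)

attn_weights = torch.matmul(query.float(), key.float().transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights + causal_mask, dim=-1).to(value.dtype)
attn_output = torch.matmul(attn_weights, value)  # (bsz, num_heads, q_len, head_dim)
print(attn_output.shape)                         # torch.Size([1, 2, 4, 8])
```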
404
+
405
+ class PhiFlashAttention2(PhiAttention):
406
+ """
407
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stay
408
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
409
+ flash attention and deal with padding tokens in case the input contains any of them.
410
+ """
411
+
412
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
413
+ def __init__(self, *args, **kwargs):
414
+ super().__init__(*args, **kwargs)
415
+
416
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
417
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
418
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
419
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.Tensor,
424
+ attention_mask: Optional[torch.LongTensor] = None,
425
+ position_ids: Optional[torch.LongTensor] = None,
426
+ past_key_value: Optional[Cache] = None,
427
+ output_attentions: bool = False,
428
+ use_cache: bool = False,
429
+ **kwargs,
430
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
431
+ # PhiFlashAttention2 attention does not support output_attentions
432
+
433
+ output_attentions = False
434
+
435
+ bsz, q_len, _ = hidden_states.size()
436
+
437
+ query_states = self.q_proj(hidden_states)
438
+ key_states = self.k_proj(hidden_states)
439
+ value_states = self.v_proj(hidden_states)
440
+
441
+ if self.qk_layernorm:
442
+ query_states = self.q_layernorm(query_states)
443
+ key_states = self.k_layernorm(key_states)
444
+
445
+ # Flash Attention expects the layout [batch_size, seq_length, num_heads, head_dim]; we first build the
+ # usual [batch_size, num_heads, seq_length, head_dim] layout for the rotary/cache code below and
+ # transpose back right before calling flash attention.
448
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
449
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
450
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
451
+
452
+ kv_seq_len = key_states.shape[-2]
453
+ if past_key_value is not None:
454
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
455
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
456
+
457
+ # Partial rotary embedding
458
+ query_rot, query_pass = (
459
+ query_states[..., : self.rotary_emb.dim],
460
+ query_states[..., self.rotary_emb.dim :],
461
+ )
462
+ key_rot, key_pass = (
463
+ key_states[..., : self.rotary_emb.dim],
464
+ key_states[..., self.rotary_emb.dim :],
465
+ )
466
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
467
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
468
+
469
+ # [batch_size, seq_length, num_heads, head_dim]
470
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
471
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
472
+
473
+ if past_key_value is not None:
474
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
475
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
476
+
477
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
478
+ # to be able to avoid many of these transpose/reshape/view.
479
+ query_states = query_states.transpose(1, 2)
480
+ key_states = key_states.transpose(1, 2)
481
+ value_states = value_states.transpose(1, 2)
482
+
483
+ attn_dropout = self.attention_dropout if self.training else 0.0
484
+
485
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # in fp32.
490
+
491
+ if query_states.dtype == torch.float32:
492
+ if torch.is_autocast_enabled():
493
+ target_dtype = torch.get_autocast_gpu_dtype()
494
+ # Handle the case where the model is quantized
495
+ elif hasattr(self.config, "_pre_quantization_dtype"):
496
+ target_dtype = self.config._pre_quantization_dtype
497
+ else:
498
+ target_dtype = self.q_proj.weight.dtype
499
+
500
+ logger.warning_once(
501
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
502
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
503
+ f" {target_dtype}."
504
+ )
505
+
506
+ query_states = query_states.to(target_dtype)
507
+ key_states = key_states.to(target_dtype)
508
+ value_states = value_states.to(target_dtype)
509
+
510
+ attn_output = self._flash_attention_forward(
511
+ query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
512
+ )
513
+
514
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
515
+ attn_output = self.dense(attn_output)
516
+
517
+ if not output_attentions:
518
+ attn_weights = None
519
+
520
+ return attn_output, attn_weights, past_key_value
521
+
522
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
523
+ def _flash_attention_forward(
524
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
525
+ ):
526
+ """
527
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+ it first unpads the input, then computes the attention scores and re-pads the final attention scores.
529
+ Args:
530
+ query_states (`torch.Tensor`):
531
+ Input query states to be passed to Flash Attention API
532
+ key_states (`torch.Tensor`):
533
+ Input key states to be passed to Flash Attention API
534
+ value_states (`torch.Tensor`):
535
+ Input value states to be passed to Flash Attention API
536
+ attention_mask (`torch.Tensor`):
537
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
538
+ position of padding tokens and 1 for the position of non-padding tokens.
539
+ dropout (`float`, *optional*):
540
+ Attention dropout
541
+ softmax_scale (`float`, *optional*):
542
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
543
+ """
544
+ if not self._flash_attn_uses_top_left_mask:
545
+ causal = self.is_causal
546
+ else:
547
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
548
+ causal = self.is_causal and query_length != 1
549
+
550
+ # Contains at least one padding token in the sequence
551
+ if attention_mask is not None:
552
+ batch_size = query_states.shape[0]
553
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
554
+ query_states, key_states, value_states, attention_mask, query_length
555
+ )
556
+
557
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
558
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
559
+
560
+ attn_output_unpad = flash_attn_varlen_func(
561
+ query_states,
562
+ key_states,
563
+ value_states,
564
+ cu_seqlens_q=cu_seqlens_q,
565
+ cu_seqlens_k=cu_seqlens_k,
566
+ max_seqlen_q=max_seqlen_in_batch_q,
567
+ max_seqlen_k=max_seqlen_in_batch_k,
568
+ dropout_p=dropout,
569
+ softmax_scale=softmax_scale,
570
+ causal=causal,
571
+ )
572
+
573
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
574
+ else:
575
+ attn_output = flash_attn_func(
576
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
577
+ )
578
+
579
+ return attn_output
580
+
581
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
582
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
583
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
584
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
585
+
586
+ key_layer = index_first_axis(
587
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
588
+ )
589
+ value_layer = index_first_axis(
590
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
591
+ )
592
+ if query_length == kv_seq_len:
593
+ query_layer = index_first_axis(
594
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
595
+ )
596
+ cu_seqlens_q = cu_seqlens_k
597
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
598
+ indices_q = indices_k
599
+ elif query_length == 1:
600
+ max_seqlen_in_batch_q = 1
601
+ cu_seqlens_q = torch.arange(
602
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
603
+ ) # There is a memcpy here, that is very bad.
604
+ indices_q = cu_seqlens_q[:-1]
605
+ query_layer = query_layer.squeeze(1)
606
+ else:
607
+ # The -q_len: slice assumes left padding.
608
+ attention_mask = attention_mask[:, -query_length:]
609
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
610
+
611
+ return (
612
+ query_layer,
613
+ key_layer,
614
+ value_layer,
615
+ indices_q,
616
+ (cu_seqlens_q, cu_seqlens_k),
617
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
618
+ )
619
+
620
+
621
+ PHI_ATTENTION_CLASSES = {
622
+ "eager": PhiAttention,
623
+ "flash_attention_2": PhiFlashAttention2,
624
+ }
625
+
626
+
627
+ class PhiDecoderLayer(nn.Module):
628
+ def __init__(self, config: PhiConfig, layer_idx: int):
629
+ super().__init__()
630
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
631
+ self.mlp = PhiMLP(config)
632
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
633
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
634
+
635
+ def forward(
636
+ self,
637
+ hidden_states: torch.Tensor,
638
+ attention_mask: Optional[torch.Tensor] = None,
639
+ position_ids: Optional[torch.LongTensor] = None,
640
+ output_attentions: Optional[bool] = False,
641
+ use_cache: Optional[bool] = False,
642
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
643
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
644
+ """
645
+ Args:
646
+ hidden_states (`torch.FloatTensor`):
647
+ input to the layer of shape `(batch, seq_len, embed_dim)`
648
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
649
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
650
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
651
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
652
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
653
+ output_attentions (`bool`, *optional*):
654
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
655
+ returned tensors for more detail.
656
+ use_cache (`bool`, *optional*):
657
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
658
+ (see `past_key_values`).
659
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
660
+ """
661
+
662
+ residual = hidden_states
663
+
664
+ hidden_states = self.input_layernorm(hidden_states)
665
+
666
+ # Self Attention
667
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
668
+ hidden_states=hidden_states,
669
+ attention_mask=attention_mask,
670
+ position_ids=position_ids,
671
+ past_key_value=past_key_value,
672
+ output_attentions=output_attentions,
673
+ use_cache=use_cache,
674
+ )
675
+ attn_outputs = self.resid_dropout(attn_outputs)
676
+
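+ # Note: Phi uses a parallel residual block. The MLP below consumes the same layer-normalized
+ # `hidden_states` that were fed to self-attention (not the attention output), and both branch
+ # outputs are added to the residual in a single sum.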
677
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
678
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
679
+ outputs = (hidden_states,)
680
+
681
+ if output_attentions:
682
+ outputs += (self_attn_weights,)
683
+
684
+ if use_cache:
685
+ outputs += (present_key_value,)
686
+
687
+ return outputs
688
+
689
+
690
+ PHI_START_DOCSTRING = r"""
691
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
692
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
693
+ etc.)
694
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
695
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
696
+ and behavior.
697
+ Parameters:
698
+ config ([`PhiConfig`]):
699
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
700
+ load the weights associated with the model, only the configuration. Check out the
701
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
702
+ """
703
+
704
+
705
+ @add_start_docstrings(
706
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
707
+ PHI_START_DOCSTRING,
708
+ )
709
+ class PhiPreTrainedModel(PreTrainedModel):
710
+ config_class = PhiConfig
711
+ base_model_prefix = "model"
712
+ supports_gradient_checkpointing = True
713
+ _no_split_modules = ["PhiDecoderLayer"]
714
+ _skip_keys_device_placement = "past_key_values"
715
+ _supports_flash_attn_2 = True
716
+ _supports_cache_class = True
717
+
718
+ def _init_weights(self, module):
719
+ std = self.config.initializer_range
720
+ if isinstance(module, nn.Linear):
721
+ module.weight.data.normal_(mean=0.0, std=std)
722
+ if module.bias is not None:
723
+ module.bias.data.zero_()
724
+ elif isinstance(module, nn.Embedding):
725
+ module.weight.data.normal_(mean=0.0, std=std)
726
+ if module.padding_idx is not None:
727
+ module.weight.data[module.padding_idx].zero_()
728
+
729
+
730
+ PHI_INPUTS_DOCSTRING = r"""
731
+ Args:
732
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
733
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
734
+ it.
735
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
736
+ [`PreTrainedTokenizer.__call__`] for details.
737
+ [What are input IDs?](../glossary#input-ids)
738
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
739
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
740
+ - 1 for tokens that are **not masked**,
741
+ - 0 for tokens that are **masked**.
742
+ [What are attention masks?](../glossary#attention-mask)
743
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
744
+ [`PreTrainedTokenizer.__call__`] for details.
745
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
746
+ `past_key_values`).
747
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
748
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
749
+ information on the default strategy.
750
+ - 1 indicates the head is **not masked**,
751
+ - 0 indicates the head is **masked**.
752
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
753
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
754
+ config.n_positions - 1]`.
755
+ [What are position IDs?](../glossary#position-ids)
756
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
757
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
758
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
759
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
760
+ Two formats are allowed:
761
+ - a [`~cache_utils.Cache`] instance;
762
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
763
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
764
+ cache format.
765
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
766
+ legacy cache format will be returned.
767
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
768
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
769
+ of shape `(batch_size, sequence_length)`.
770
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
771
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
772
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
773
+ model's internal embedding lookup matrix.
774
+ use_cache (`bool`, *optional*):
775
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
776
+ `past_key_values`).
777
+ output_attentions (`bool`, *optional*):
778
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
779
+ tensors for more detail.
780
+ output_hidden_states (`bool`, *optional*):
781
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
782
+ more detail.
783
+ return_dict (`bool`, *optional*):
784
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
785
+ """
786
+
787
+
788
+ @add_start_docstrings(
789
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
790
+ PHI_START_DOCSTRING,
791
+ )
792
+ class PhiModel(PhiPreTrainedModel):
793
+ """
794
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
795
+ Args:
796
+ config: PhiConfig
797
+ """
798
+
799
+ def __init__(self, config: PhiConfig):
800
+ super().__init__(config)
801
+ self.padding_idx = config.pad_token_id
802
+ self.vocab_size = config.vocab_size
803
+
804
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
805
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
806
+ self.layers = nn.ModuleList(
807
+ [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
808
+ )
809
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
810
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
811
+
812
+ self.gradient_checkpointing = False
813
+ # Initialize weights and apply final processing
814
+ self.post_init()
815
+
816
+ def get_input_embeddings(self):
817
+ return self.embed_tokens
818
+
819
+ def set_input_embeddings(self, value):
820
+ self.embed_tokens = value
821
+
822
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
823
+ def forward(
824
+ self,
825
+ input_ids: torch.LongTensor = None,
826
+ attention_mask: Optional[torch.Tensor] = None,
827
+ position_ids: Optional[torch.LongTensor] = None,
828
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
829
+ inputs_embeds: Optional[torch.FloatTensor] = None,
830
+ use_cache: Optional[bool] = None,
831
+ output_attentions: Optional[bool] = None,
832
+ output_hidden_states: Optional[bool] = None,
833
+ return_dict: Optional[bool] = None,
834
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
835
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
836
+ output_hidden_states = (
837
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
838
+ )
839
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
840
+
841
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
842
+
843
+ # retrieve input_ids and inputs_embeds
844
+ if input_ids is not None and inputs_embeds is not None:
845
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
846
+ elif input_ids is not None:
847
+ batch_size, seq_length = input_ids.shape[:2]
848
+ elif inputs_embeds is not None:
849
+ batch_size, seq_length = inputs_embeds.shape[:2]
850
+ else:
851
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
852
+
853
+ past_key_values_length = 0
854
+
855
+ if self.gradient_checkpointing and self.training:
856
+ if use_cache:
857
+ logger.warning_once(
858
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
859
+ )
860
+ use_cache = False
861
+
862
+ if use_cache:
863
+ use_legacy_cache = not isinstance(past_key_values, Cache)
864
+ if use_legacy_cache:
865
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
866
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
867
+
868
+ if position_ids is None:
869
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
870
+ position_ids = torch.arange(
871
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
872
+ )
873
+ position_ids = position_ids.unsqueeze(0)
874
+
875
+ if inputs_embeds is None:
876
+ inputs_embeds = self.embed_tokens(input_ids)
877
+
878
+ inputs_embeds = self.embed_dropout(inputs_embeds)
879
+
880
+ # Attention mask.
881
+ if self._use_flash_attention_2:
882
+ # 2d mask is passed through the layers
883
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
884
+ else:
885
+ # 4d mask is passed through the layers
886
+ attention_mask = _prepare_4d_causal_attention_mask(
887
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
888
+ )
889
+
890
+ hidden_states = inputs_embeds
891
+
892
+ # decoder layers
893
+ all_hidden_states = () if output_hidden_states else None
894
+ all_self_attns = () if output_attentions else None
895
+ next_decoder_cache = None
896
+
897
+ for decoder_layer in self.layers:
898
+ if output_hidden_states:
899
+ all_hidden_states += (hidden_states,)
900
+
901
+ if self.gradient_checkpointing and self.training:
902
+ layer_outputs = self._gradient_checkpointing_func(
903
+ decoder_layer.__call__,
904
+ hidden_states,
905
+ attention_mask,
906
+ position_ids,
907
+ past_key_values,
908
+ output_attentions,
909
+ )
910
+ else:
911
+ layer_outputs = decoder_layer(
912
+ hidden_states,
913
+ attention_mask=attention_mask,
914
+ position_ids=position_ids,
915
+ past_key_value=past_key_values,
916
+ output_attentions=output_attentions,
917
+ use_cache=use_cache,
918
+ )
919
+
920
+ hidden_states = layer_outputs[0]
921
+
922
+ if use_cache:
923
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
924
+
925
+ if output_attentions:
926
+ all_self_attns += (layer_outputs[1],)
927
+
928
+ hidden_states = self.final_layernorm(hidden_states)
929
+
930
+ # add hidden states from the last decoder layer
931
+ if output_hidden_states:
932
+ all_hidden_states += (hidden_states,)
933
+
934
+ next_cache = None
935
+ if use_cache:
936
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
937
+ if not return_dict:
938
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
939
+ return BaseModelOutputWithPast(
940
+ last_hidden_state=hidden_states,
941
+ past_key_values=next_cache,
942
+ hidden_states=all_hidden_states,
943
+ attentions=all_self_attns,
944
+ )
945
+
946
+
947
+ class PhiForCausalLM(PhiPreTrainedModel):
948
+ _tied_weights_keys = ["lm_head.weight"]
949
+
950
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
951
+ def __init__(self, config):
952
+ super().__init__(config)
953
+ self.model = PhiModel(config)
954
+ self.vocab_size = config.vocab_size
955
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
956
+
957
+ # Initialize weights and apply final processing
958
+ self.post_init()
959
+
960
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
961
+ def get_input_embeddings(self):
962
+ return self.model.embed_tokens
963
+
964
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
965
+ def set_input_embeddings(self, value):
966
+ self.model.embed_tokens = value
967
+
968
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
969
+ def get_output_embeddings(self):
970
+ return self.lm_head
971
+
972
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
973
+ def set_output_embeddings(self, new_embeddings):
974
+ self.lm_head = new_embeddings
975
+
976
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
977
+ def set_decoder(self, decoder):
978
+ self.model = decoder
979
+
980
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
981
+ def get_decoder(self):
982
+ return self.model
983
+
984
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
985
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
986
+ def forward(
987
+ self,
988
+ input_ids: torch.LongTensor = None,
989
+ attention_mask: Optional[torch.Tensor] = None,
990
+ position_ids: Optional[torch.LongTensor] = None,
991
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
992
+ inputs_embeds: Optional[torch.FloatTensor] = None,
993
+ labels: Optional[torch.LongTensor] = None,
994
+ use_cache: Optional[bool] = None,
995
+ output_attentions: Optional[bool] = None,
996
+ output_hidden_states: Optional[bool] = None,
997
+ return_dict: Optional[bool] = None,
998
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
999
+ r"""
1000
+ Args:
1001
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1002
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1003
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1004
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1005
+ Returns:
1006
+ Example:
1007
+ ```python
1008
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1009
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1010
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1011
+ >>> prompt = "This is an example script ."
1012
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1013
+ >>> # Generate
1014
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1015
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1016
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1017
+ ```"""
1018
+
1019
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1020
+ output_hidden_states = (
1021
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1022
+ )
1023
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1024
+
1025
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1026
+ outputs = self.model(
1027
+ input_ids=input_ids,
1028
+ attention_mask=attention_mask,
1029
+ position_ids=position_ids,
1030
+ past_key_values=past_key_values,
1031
+ inputs_embeds=inputs_embeds,
1032
+ use_cache=use_cache,
1033
+ output_attentions=output_attentions,
1034
+ output_hidden_states=output_hidden_states,
1035
+ return_dict=return_dict,
1036
+ )
1037
+
1038
+ hidden_states = outputs[0]
1039
+ logits = self.lm_head(hidden_states)
1040
+ logits = logits.float()
1041
+
1042
+ loss = None
1043
+ if labels is not None:
1044
+ # Shift so that tokens < n predict n
1045
+ shift_logits = logits[..., :-1, :].contiguous()
1046
+ shift_labels = labels[..., 1:].contiguous()
1047
+ # Flatten the tokens
1048
+ loss_fct = CrossEntropyLoss()
1049
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1050
+ shift_labels = shift_labels.view(-1)
1051
+ # Enable model parallelism
1052
+ shift_labels = shift_labels.to(shift_logits.device)
1053
+ loss = loss_fct(shift_logits, shift_labels)
+ 
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+ 
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+ 
+     # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         if past_key_values is not None:
+             if isinstance(past_key_values, Cache):
+                 cache_length = past_key_values.get_seq_length()
+                 past_length = past_key_values.seen_tokens
+                 max_cache_length = past_key_values.get_max_length()
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+                 max_cache_length = None
+ 
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+             # input)
+             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             # input_ids based on the past_length.
+             elif past_length < input_ids.shape[1]:
+                 input_ids = input_ids[:, past_length:]
+             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+ 
+             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+             if (
+                 max_cache_length is not None
+                 and attention_mask is not None
+                 and cache_length + input_ids.shape[1] > max_cache_length
+             ):
+                 attention_mask = attention_mask[:, -max_cache_length:]
+ 
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+ 
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+ 
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
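+         # For example, when generating with a legacy tuple cache that already holds `past_length` tokens,
+         # only `input_ids[:, past_length:]` (typically the single newest token) survives the slicing above,
+         # and `position_ids` is sliced to the same length, so each decoding step re-encodes only new tokens.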
+ 
+     @staticmethod
+     # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
+     def _reorder_cache(past_key_values, beam_idx):
+         reordered_past = ()
+         for layer_past in past_key_values:
+             reordered_past += (
+                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+             )
+         return reordered_past
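+         # Example: beam_idx = tensor([2, 0]) rebuilds the cache so that entry 0 holds the key/value
+         # states of former beam 2 and entry 1 those of former beam 0, keeping the cache aligned with
+         # the beams selected during beam search.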
+ 
+ 
+ @add_start_docstrings(
+     """
+     The PhiModel with a sequence classification head on top (linear layer).
+     [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+     (e.g. GPT-2) do.
+     Since it does classification on the last token, it needs to know the position of the last token. If a
+     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+     each row of the batch).
+     """,
+     PHI_START_DOCSTRING,
+ )
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
+ class PhiForSequenceClassification(PhiPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = PhiModel(config)
+         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+ 
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+ 
+     @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         model_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = model_outputs[0]
+         logits = self.score(hidden_states)
+ 
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+ 
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+ 
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
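+         # Example: with pad_token_id = 0 and input_ids = [[5, 7, 9, 0, 0]], the first pad sits at index 3,
+         # so sequence_lengths = 2 and the logits of the last non-padding token are pooled; if a row contains
+         # no pad token, the modulo above makes the index wrap to the final position.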
+ 
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+ 
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(pooled_logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(pooled_logits, labels)
+         if not return_dict:
+             output = (pooled_logits,) + model_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=model_outputs.past_key_values,
+             hidden_states=model_outputs.hidden_states,
+             attentions=model_outputs.attentions,
+         )
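+ 
+ # A minimal usage sketch for the class above (illustrative only, not part of the upstream module;
+ # the checkpoint name and num_labels=2 are placeholder choices):
+ #
+ #     from transformers import AutoTokenizer, PhiForSequenceClassification
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
+ #     model = PhiForSequenceClassification.from_pretrained("microsoft/phi-1", num_labels=2)
+ #     inputs = tokenizer("This movie was great!", return_tensors="pt")
+ #     logits = model(**inputs).logits              # shape (1, 2): one score per label
+ #     predicted_class = logits.argmax(-1).item()   # index of the highest-scoring label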
+ 
+ 
+ @add_start_docstrings(
+     """
+     PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+     Named-Entity-Recognition (NER) tasks.
+     """,
+     PHI_START_DOCSTRING,
+ )
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
+ class PhiForTokenClassification(PhiPreTrainedModel):
+     def __init__(self, config: PhiConfig):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+ 
+         self.model = PhiModel(config)
+         if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+             classifier_dropout = config.classifier_dropout
+         elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+             classifier_dropout = config.hidden_dropout
+         else:
+             classifier_dropout = 0.1
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=TokenClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **deprecated_arguments,
+     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the token classification loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         model_outputs = self.model(
+             input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         hidden_states = model_outputs[0]
+         hidden_states = self.dropout(hidden_states)
+         logits = self.classifier(hidden_states)
+ 
+         loss = None
+         if labels is not None:
+             # move labels to correct device to enable model parallelism
+             labels = labels.to(logits.device)
+             batch_size, seq_length = labels.shape
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(
+                 logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+             )
+ 
+         if not return_dict:
+             output = (logits,) + model_outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=model_outputs.hidden_states,
+             attentions=model_outputs.attentions,
+         )
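+ 
+ # A minimal usage sketch for the class above (illustrative only, not part of the upstream module;
+ # the checkpoint name and num_labels=5 are placeholder choices):
+ #
+ #     from transformers import AutoTokenizer, PhiForTokenClassification
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
+ #     model = PhiForTokenClassification.from_pretrained("microsoft/phi-1", num_labels=5)
+ #     inputs = tokenizer("Alice moved to Paris in 2020", return_tensors="pt")
+ #     logits = model(**inputs).logits      # shape (1, sequence_length, 5): one score per token and label
+ #     predictions = logits.argmax(-1)      # per-token predicted label ids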