import gradio as gr
import torch
import transformers
def reduce_sum(value, mask, axis=None):
    if axis is None:
        return torch.sum(value * mask)
    return torch.sum(value * mask, axis)

def reduce_mean(value, mask, axis=None):
    if axis is None:
        return torch.sum(value * mask) / torch.sum(mask)
    return reduce_sum(value, mask, axis) / torch.sum(mask, axis)
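# Illustrative example: reduce_mean averages only the unmasked entries, e.g.
#   reduce_mean(torch.tensor([[2., 4., 6.]]), torch.tensor([[1, 1, 0]]), axis=-1)
#   -> tensor([3.])  # (2 + 4) / 2; the masked-out 6. is ignored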
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class InteractiveRainier:

    def __init__(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
        self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
        self.qa_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('allenai/unifiedqa-t5-large').to(device)
        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
    def parse_choices(self, s):
        '''
        s: serialized_choices '(A) ... (B) ... (C) ...'
        '''
        choices = []
        key = 'A' if s.find('(A)') != -1 else 'a'
        while True:
            pos = s.find(f'({chr(ord(key) + 1)})')
            if pos == -1:
                break
            choice = s[3:pos]
            s = s[pos:]
            choice = choice.strip(' ')
            choices.append(choice)
            key = chr(ord(key) + 1)
        choice = s[3:]
        choice = choice.strip(' ')
        choices.append(choice)
        return choices
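    # Illustrative example:
    #   self.parse_choices('(A) gets bigger (B) gets smaller')
    #   -> ['gets bigger', 'gets smaller']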
    def run(self, question, max_input_len, max_output_len, m, top_p):
        tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1, L)
        knowledges_ids = self.rainier_model.generate(
            input_ids=tokenized.input_ids,
            max_length=max_output_len + 1,
            min_length=3,
            do_sample=True,
            num_return_sequences=m,
            top_p=top_p,
        )  # (K, L); begins with 0 ([BOS]); ends with 1 ([EOS])
        knowledges_ids = knowledges_ids[:, 1:].contiguous()  # no beginning; ends with 1 ([EOS])
        knowledges = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        knowledges = list(set(knowledges))  # deduplicate sampled knowledges
        knowledges = [''] + knowledges  # prepend the no-knowledge baseline
        prompts = [question + (f' \\n {knowledge}' if knowledge != '' else '') for knowledge in knowledges]
        choices = self.parse_choices(question.split('\\n')[1].strip(' '))
        prompts = [prompt.lower() for prompt in prompts]
        choices = [choice.lower() for choice in choices]
        # The prompts do not depend on the choice, so tokenize them once before the loop
        tokenized_prompts = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1+K, L)
        answer_logitss = []
        for choice in choices:
            tokenized_choices = self.tokenizer([choice], return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1, L)
            pad_mask = (tokenized_choices.input_ids == self.tokenizer.pad_token_id)
            tokenized_choices.input_ids[pad_mask] = -100  # exclude pad positions from the loss
            tokenized_choices.input_ids = tokenized_choices.input_ids.repeat(len(knowledges), 1)  # (1+K, L)
            with torch.no_grad():
                logits = self.qa_model(
                    input_ids=tokenized_prompts.input_ids,
                    attention_mask=tokenized_prompts.attention_mask,
                    labels=tokenized_choices.input_ids,
                ).logits  # (1+K, L, V)
            losses = self.loss_fct(logits.view(-1, logits.size(-1)), tokenized_choices.input_ids.view(-1))
            losses = losses.view(tokenized_choices.input_ids.shape)  # (1+K, L)
            losses = reduce_mean(losses, ~pad_mask, axis=-1)  # (1+K)
            answer_logitss.append(-losses)  # negative loss, so higher is better
        answer_logitss = torch.stack(answer_logitss, dim=1)  # (1+K, C)
        answer_probss = answer_logitss.softmax(dim=1)  # (1+K, C)
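        # Illustrative example: with C = 2 choices and K = 2 knowledges, suppose
        # answer_probss = [[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]] (row 0 = no knowledge).
        # The knowledge-less prediction is choice 0 (prob 0.6); the ensembled
        # prediction takes the column-wise max [0.8, 0.7] and also picks choice 0;
        # the selected knowledge is the row with the highest max prob (row 2, 0.8).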
        # Ensemble
        knowless_pred = answer_probss[0, :].argmax(dim=0).item()
        knowless_pred = choices[knowless_pred]
        answer_probs = answer_probss.max(dim=0).values  # (C)
        knowful_pred = answer_probs.argmax(dim=0).item()
        knowful_pred = choices[knowful_pred]
        selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
        selected_knowledge = knowledges[selected_knowledge_ix]
        return {
            'question': question,
            'knowledges': knowledges,
            'knowless_pred': knowless_pred,
            'knowful_pred': knowful_pred,
            'selected_knowledge': selected_knowledge,
        }
rainier = InteractiveRainier()
def predict(question, kg_model, qa_model, max_input_len, max_output_len, m, top_p):
    # kg_model and qa_model are display-only fields; the actual models are fixed at load time
    result = rainier.run(question, max_input_len, max_output_len, m, top_p)
    return result['knowless_pred'], result['knowful_pred'], '\n'.join(result['knowledges']), result['selected_knowledge']
examples = [
    'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
    'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
    'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
    'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
    'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
    'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
]
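# Illustrative example of calling predict() directly, bypassing the UI:
#   knowless, knowful, all_knowledges, selected = predict(
#       examples[0], 'liujch1998/rainier-large', 'allenai/unifiedqa-t5-large',
#       max_input_len=256, max_output_len=32, m=10, top_p=0.5)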
input_question = gr.Dropdown(choices=examples, label='Question:',
    info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
)
input_kg_model = gr.Textbox(label='Knowledge generation model:', value='liujch1998/rainier-large', interactive=False)
input_qa_model = gr.Textbox(label='QA model:', value='allenai/unifiedqa-t5-large', interactive=False)
input_max_input_len = gr.Number(label='Max number of tokens in question:', value=256, precision=0)
input_max_output_len = gr.Number(label='Max number of tokens in knowledge:', value=32, precision=0)
input_m = gr.Slider(label='Number of generated knowledges:', value=10, minimum=1, maximum=20, step=1,
    info='The actual number of generated knowledges may be less than this number due to possible duplicates.',
)
input_top_p = gr.Slider(label='top_p for knowledge generation:', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
output_knowless_answer = gr.Textbox(label='QA model answer without knowledge:', interactive=False)
output_knowful_answer = gr.Textbox(label='QA model answer with knowledge:', interactive=False)
output_all_knowledges = gr.Textbox(label='All generated knowledges:', interactive=False)
output_selected_knowledge = gr.Textbox(label='Knowledge selected to make the prediction:', interactive=False)
description = '''This is a demo for the paper, [*Rainier: Reinforced Knowledge Introspector for Commonsense Question Answering*](https://arxiv.org/pdf/2210.03078.pdf), presented at EMNLP 2022. [[Code](https://github.com/liujch1998/rainier)] [[Model](https://huggingface.co/liujch1998/rainier-large)] This demo is made and maintained by [Jiacheng (Gary) Liu](https://liujch1998.github.io).

Rainier is a knowledge-generating model that enhances the commonsense QA capability of a QA model. To try this model, select an example question, or write your own commonsense question in the suggested format.'''
gr.Interface(
    fn=predict,
    inputs=[input_question, input_kg_model, input_qa_model, input_max_input_len, input_max_output_len, input_m, input_top_p],
    outputs=[output_knowless_answer, output_knowful_answer, output_all_knowledges, output_selected_knowledge],
    title='Rainier Demo',
    description=description,
).launch()
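# Usage note (assumes standard Gradio behavior): running `python app.py` starts a
# local server and prints a URL to open in the browser; pass share=True to
# .launch() above to get a temporary public link instead.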