# Hugging Face Space: Shakespeare transformer sampling demo.
# (Space status at capture time: Runtime error.)
#%%
# Standard-library imports.
import re

# Third-party imports.
import torch as t
import yaml
import gradio as gr

# Project-local modules.
import sampling
import transformer_replication
from word_data import WordData

#%%
# True only when executed as a script (guards prints/launches on import).
MAIN = __name__ == '__main__'
# Prefer the GPU whenever one is available.
device = 'cuda' if t.cuda.is_available() else 'cpu'
#%%
# Build the word-level dataset from the Project Gutenberg dump of the
# complete works, keeping only the span between the first sonnet ("1\n")
# and the start of the plays.
shakespeare = WordData.from_file(
    '100-0.txt',
    device=device,
    start="1\n",
    end='ALL’S WELL THAT ENDS WELL',
)
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))
#%%
#%%
# Training hyperparameters exported by W&B: each key maps to {'value': ...}.
with open('config.yaml', 'r', encoding='utf-8') as f:
    yaml_cfg = yaml.safe_load(f)
#%%
# BUG FIX: the original wrapped this in `with open('model_state_dict.pt') as f:`,
# which opened the binary checkpoint in *text* mode and never used the handle —
# t.load takes the path directly.
state_dict = t.load(
    'model_state_dict.pt',
    map_location=device,  # remap CUDA-saved tensors onto the active device
)
#%%
# Rebuild the transformer exactly as trained: hyperparameters come from the
# W&B config dump, vocab size from the dataset itself.
max_seq_len = yaml_cfg['max_seq_len']['value']
base_config = transformer_replication.TransformerConfig(
    num_layers=yaml_cfg['num_layers']['value'],
    num_heads=yaml_cfg['num_heads']['value'],
    vocab_size=len(shakespeare.vocab),
    hidden_size=yaml_cfg['hidden_size']['value'],
    max_seq_len=max_seq_len,
    dropout=yaml_cfg['dropout']['value'],
)
# The tokenizer must truncate prompts to what the model can attend over.
shakespeare.model_max_length = max_seq_len
model = transformer_replication.DecoderOnlyTransformer(base_config)
model.load_state_dict(state_dict)
#%%
def generate(
    text: str, max_tokens: int, temperature: float,
    top_k: int,
) -> str:
    """Sample a continuation of `text` from the trained Shakespeare model."""
    sampling_kwargs = dict(
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )
    return sampling.sample_tokens(model, shakespeare, text, **sampling_kwargs)
#%%
def safe_generate(
    text: str, max_tokens: int = 300, temperature: float = 1.0,
    top_k: int = 20,
) -> str:
    """Sample a completion, trimming a trailing sonnet number, with a
    friendly message when the prompt contains out-of-vocabulary words."""
    try:
        raw = generate(
            text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
        )
    except KeyError as e:
        # The tokenizer raises KeyError on words absent from the corpus.
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"
    # Cut the sample just before a generated sonnet number (e.g. "\n18\n").
    numbered = re.match(r"(?P<start>\D*)\d+\n", raw)
    return raw if numbered is None else numbered.group('start')
#%%
# Example prompts displayed under the Gradio form (one-element rows: the
# slider inputs keep their default values).
examples = [
    [prompt]
    for prompt in (
        "I sang a beautiful song",
        "To be free is to",
        "How I love thee",
    )
]
#%%
# Quick sampling smoke test when executed directly.
if MAIN:
    print(safe_generate('How I love thee'))
#%%
# Shown above the input form in the Gradio UI.
description = """
Provide a prompt in the "Input Text" window below and then click "Submit".
The small Shakespeare transformer model trained on my laptop will attempt to
complete the Sonnet that you started.
Thanks to Project Gutenberg for providing the training corpus.
"""
#%%
def make_demo():
    """Assemble the Gradio interface around `safe_generate` and launch it."""
    prompt_box = gr.components.Textbox(lines=5, label="Input Text")
    max_tokens_slider = gr.components.Slider(
        label='max tokens generated', minimum=1, maximum=1000,
        value=300, step=1,
    )
    temperature_slider = gr.components.Slider(
        label='temperature', minimum=0, maximum=2, value=1, step=0.1,
    )
    top_k_slider = gr.components.Slider(
        label='top_k', minimum=1, maximum=100, value=10, step=1,
    )
    demo = gr.Interface(
        fn=safe_generate,
        inputs=[
            prompt_box,
            max_tokens_slider,
            temperature_slider,
            top_k_slider,
        ],
        outputs=gr.components.Textbox(label="Generated Text"),
        examples=examples,
        title='Shakespeare transformer sampling',
        description=description,
    )
    demo.launch()
#%%