"""Sample short code completions from a local causal-LM checkpoint.

Builds a Hugging Face ``text-generation`` pipeline from ``model_ckpt`` and,
when run as a script, prints several sampled completions of a Python
function stub. Each completion is truncated to its first logical block.
"""
import re

from transformers import pipeline, set_seed

model_ckpt = './'
# device=0 selects the first CUDA device for the pipeline.
generation = pipeline('text-generation', model=model_ckpt, device=0)

# A completion is cut where the model starts a new construct at column 0:
# a line beginning with class/def/#/@/print/if.
_BLOCK_END_RE = re.compile(r'\n(?:class|def|#|@|print|if)')

# Sampling settings. The temperature is low so output stays close to the
# most likely code. top_k=0 disables top-k filtering, so only nucleus
# (top_p) sampling limits the candidate tokens.
_GEN_KWARGS = {
    "temperature": 0.4,
    "top_p": 0.95,
    "top_k": 0,
    "num_beams": 1,
    "do_sample": True,
}


def first_block(string):
    """Return the leading block of *string* with trailing whitespace removed.

    The text is cut at the first newline that is directly followed by
    ``class``, ``def``, ``#``, ``@``, ``print`` or ``if``. If none is
    present, the whole string is kept. The result is right-stripped.
    """
    return _BLOCK_END_RE.split(string, maxsplit=1)[0].rstrip()


def complete_code(pipe, prompt, max_length=64, num_completions=4, seed=1):
    """Sample code completions for *prompt*, print them, and return them.

    Args:
        pipe: a ``text-generation`` pipeline used to produce the completions.
        prompt: source text the model should continue.
        max_length: forwarded to generation. In Transformers this limits the
            total length of prompt plus generated tokens.
        num_completions: number of sampled sequences to return.
        seed: value passed to ``set_seed`` so that sampling is reproducible.

    Returns:
        A list of ``num_completions`` strings. Each is only the generated
        continuation, with the prompt removed and the text cut by
        ``first_block``. The same strings are printed, separated by rules.
    """
    set_seed(seed)
    outputs = pipe(
        prompt,
        num_return_sequences=num_completions,
        max_length=max_length,
        **_GEN_KWARGS,
    )
    # With the default full-text output, 'generated_text' is the prompt
    # followed by the continuation, so slicing by len(prompt) leaves only
    # the newly generated part.
    code_strings = [
        first_block(out['generated_text'][len(prompt):]) for out in outputs
    ]
    print(('\n' + '=' * 80 + '\n').join(code_strings))
    return code_strings


if __name__ == "__main__":
    prompt = '''def area_of_rectangle(a: float, b: float):
    """Return the area of the rectangle."""'''
    complete_code(generation, prompt)