| import torch |
| import tiktoken |
| from model import GPTConfig, GPT |
|
|
| |
# --- Configuration ---
ckpt_path = '/home/user/350m_SmaLLMPro_Final/SmaLLMPro_iter_1500.pt'
device = 'cuda'
enc = tiktoken.get_encoding("gpt2")  # GPT-2 BPE tokenizer (matches training)


# --- Load checkpoint and rebuild the model ---
print("Loading SmaLLMPro...")
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# from a trusted source. weights_only=False is stated explicitly because the
# checkpoint holds a full dict (model_args + state_dict) and newer torch
# versions changed the default to True, which would make this call fail.
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)


# Checkpoints saved from a torch.compile()-d model prefix every parameter key
# with '_orig_mod.'; strip it so the keys match an uncompiled GPT module.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):  # snapshot keys: we mutate the dict while looping
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)


model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout
model.to(device)
print(f"Model {ckpt_path} ready!\n")
|
|
def run_chat():
    """Interactive chat REPL: read a prompt, sample a completion, print it.

    Uses the module-level ``model``, ``enc``, and ``device``. Exit by typing
    'exit' or 'quit', or by sending EOF (Ctrl-D) / Ctrl-C.
    """
    print("--- SmaLLMPro Chatbot (Type 'exit' to quit) ---")

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: leave the loop cleanly instead of a traceback.
            print()
            break
        if user_input.lower() in ["exit", "quit"]:
            break
        if not user_input.strip():
            # Nothing to answer — skip an expensive 500-token generation pass.
            continue

        # Format the turn the way the model was fine-tuned to expect.
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"

        # Encode and add a batch dimension: shape (1, seq_len) on `device`.
        x = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device)[None, ...]

        print("SmaLLMPro: ", end="", flush=True)
        # no_grad: skip autograd graph; bfloat16 autocast: faster inference.
        with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            y = model.generate(x, max_new_tokens=500, temperature=0.65, top_k=25)

        full_text = enc.decode(y[0].tolist())

        # Keep only the model's answer: text after the last "Response:\n"
        # (the decoded output still contains the prompt itself).
        if "Response:\n" in full_text:
            response = full_text.split("Response:\n")[-1]
        else:
            response = full_text

        # Trim at end-of-text, or at a hallucinated next "Instruction:" turn.
        response = response.split("<|endoftext|>")[0].split("Instruction:")[0].strip()
        print(response + "\n")
|
|
# Start the interactive loop only when run as a script (not on import).
if __name__ == "__main__":
    run_chat()