from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from typing import List, Optional
import yaml
from types import SimpleNamespace


app = FastAPI(title="PVS Step Recommender API", version="1.0.0")

# ------------------------------
# Global state (loaded once)
# ------------------------------
TOKENIZER = None
MODEL = None
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

def load_config(path="config.yaml"):
    with open(path, "r") as f:
        cfg = yaml.safe_load(f)
    return SimpleNamespace(**cfg)
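
# A minimal sketch of the YAML shape load_config expects (illustrative;
# only the save_path key is read at startup below -- the real pvs_v5.yaml
# may define more fields):
#
#   save_path: ./checkpoints/pvs-step-recommender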


def load_model_and_tokenizer(path: str):
    global TOKENIZER, MODEL
    if TOKENIZER is None or MODEL is None:
        TOKENIZER = AutoTokenizer.from_pretrained(path, use_fast=True)
        # device_map="auto" lets HF place layers; torch_dtype="auto" uses the
        # checkpoint's native precision (e.g. fp16) when available
        MODEL = AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto", device_map="auto")
        # Some models have no pad token id; fall back to eos
        if TOKENIZER.pad_token_id is None and TOKENIZER.eos_token_id is not None:
            TOKENIZER.pad_token = TOKENIZER.eos_token
        print("model and tokenizer loaded")
    return TOKENIZER, MODEL


def recommend_top_k_steps(model, tokenizer, prompt: str, top_k: int = 3):
    inputs = tokenizer(prompt, max_length=2048, truncation=True, return_tensors='pt')
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Stop on EOS, plus the literal "END" marker if the tokenizer knows it
    stop_ids = {tokenizer.eos_token_id} if tokenizer.eos_token_id is not None else set()
    for token in ["END"]:
        tok_id = tokenizer.convert_tokens_to_ids(token)
        if tok_id is not None and tok_id != tokenizer.unk_token_id:
            stop_ids.add(tok_id)

    model.eval()
    with torch.no_grad():
        gen = model.generate(
            **inputs,
            do_sample=True,
            num_return_sequences=top_k,
            top_k=50,
            top_p=0.9,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            eos_token_id=list(stop_ids) if stop_ids else None,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=128,
        )

    sequences = gen.sequences
    scores = gen.scores
    prompt_len = inputs["input_ids"].shape[1]

    suggestions_with_logprob = []
    for i in range(sequences.size(0)):
        gen_ids = sequences[i, prompt_len:]
        # Decode for display; keep raw text and also split first line as the command
        gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

        total_logprob, token_count = 0.0, 0
        for t in range(min(len(scores), gen_ids.numel())):
            token_id = int(gen_ids[t].item())
            if token_id in stop_ids:
                break
            step_logits = scores[t][i]
            step_logprobs = F.log_softmax(step_logits, dim=-1)
            total_logprob += float(step_logprobs[token_id].item())
            token_count += 1

        length_norm_logprob = total_logprob / max(token_count, 1)
        suggestions_with_logprob.append({
            "log_prob": length_norm_logprob,
            "command": gen_text.split("\n")[0]
        })

    suggestions_with_logprob.sort(key=lambda x: x["log_prob"], reverse=True)
    return suggestions_with_logprob
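
# Illustrative output shape for top_k=2 (log-probs and commands are made up):
#   [{"log_prob": -0.42, "command": "(skosimp*)"},
#    {"log_prob": -0.93, "command": "(grind)"}]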


# ------------------------------
# Pydantic models
# ------------------------------
class RecommendResponse(BaseModel):
    prompt: str
    top_k: int
    suggestions: List[dict]


class RecommendRequest(BaseModel):
    sequent: str
    prev_commands: List[str]
    top_k: Optional[int] = 3

# ------------------------------
# Startup: load config and model
# ------------------------------
@app.on_event("startup")
def startup_event():
    # Allow overriding via env vars, else use YAML
    config_path = os.environ.get("PVS_API_CONFIG", "pvs_v5.yaml")
    config = load_config(config_path)

    save_path = os.environ.get("PVS_MODEL_PATH", getattr(config, 'save_path', None))
    if not save_path:
        raise RuntimeError("Model path not provided. Set PVS_MODEL_PATH or include save_path in config YAML.")

    load_model_and_tokenizer(save_path)


# ------------------------------
# Routes
# ------------------------------
@app.get("/health")
def health():
    return {"status": "ok", "device": DEVICE}


@app.get("/info")
def info():
    return {
        "model_name": getattr(MODEL.config, 'name_or_path', None),
        "vocab_size": getattr(MODEL.config, 'vocab_size', None),
        "eos_token_id": TOKENIZER.eos_token_id,
        "pad_token_id": TOKENIZER.pad_token_id,
        "device": str(MODEL.device),
    }

@app.post("/recommend", response_model=RecommendResponse)
def recommend(req: RecommendRequest):
    sequent = req.sequent.strip()
    prev_cmds = req.prev_commands or []
    prompt_lines = [f"Current Sequent:\n{sequent}\n"]
    for i, cmd in enumerate(prev_cmds):
        prompt_lines.append(f"Prev Command {i+1}: {cmd if cmd else 'None'}")
    prompt = "\n".join(prompt_lines) + "\nNext Command:\n"
    suggestions = recommend_top_k_steps(MODEL, TOKENIZER, prompt, top_k=req.top_k)
    return RecommendResponse(prompt=prompt, top_k=req.top_k, suggestions=suggestions)

    # if not prompt.strip():
    #     return JSONResponse(status_code=400, content={"error": "prompt must be a non-empty string"})

    # suggestions = recommend_top_k_steps(MODEL, TOKENIZER, prompt, top_k=top_k)
    # return RecommendResponse(prompt=prompt, top_k=top_k, suggestions=suggestions)
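
# Example request against /recommend (hypothetical sequent and commands;
# substitute your own PVS proof state):
#
#   curl -X POST http://localhost:8000/recommend \
#        -H "Content-Type: application/json" \
#        -d '{"sequent": "{1} x + 0 = x", "prev_commands": ["skosimp*"], "top_k": 3}'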


# ------------------------------
# Entrypoint for running with `python pvs_step_recommender_api.py`
# ------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("pvs_step_recommender_api:app", host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), reload=False)
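
# Equivalent launch via the uvicorn CLI (assuming this file is named
# pvs_step_recommender_api.py and a config YAML is present):
#   PVS_API_CONFIG=pvs_v5.yaml uvicorn pvs_step_recommender_api:app --host 0.0.0.0 --port 8000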