rdune71 committed on
Commit 75f72a7 · 1 Parent(s): 5082283

Refactor LLM module to support multiple providers with unified interface

Files changed (1)
  1. core/llm.py +158 -40
core/llm.py CHANGED
@@ -1,57 +1,175 @@
+import json
+import time
 import requests
+from abc import ABC, abstractmethod
+from typing import Union, Generator
 import openai
 from utils.config import config
 
-class LLMClient:
-    def __init__(self, provider="ollama", model_name=None):
-        self.provider = provider
-        self.model_name = model_name or config.local_model_name
+class LLMProvider(ABC):
+    """Abstract base class for all LLM providers"""
+
+    def __init__(self, model_name: str, timeout: int = 30, retries: int = 3):
+        self.model_name = model_name
+        self.timeout = timeout
+        self.retries = retries
+
+    @abstractmethod
+    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
+        """Generate text completion - must be implemented by subclasses"""
+        pass
+
+    def _retry_request(self, func, *args, **kwargs):
+        """Generic retry wrapper with exponential backoff"""
+        last_exception = None
+        for attempt in range(self.retries + 1):
+            try:
+                return func(*args, **kwargs)
+            except Exception as e:
+                last_exception = e
+                if attempt < self.retries:
+                    time.sleep(1 * (2 ** attempt))  # Exponential backoff
+                    continue
+        raise last_exception
+
 
-        # Set up OpenAI client for Hugging Face endpoint
-        self.hf_client = openai.OpenAI(
+class OllamaProvider(LLMProvider):
+    def __init__(self, model_name: str, host: str = None, timeout: int = 30, retries: int = 3):
+        super().__init__(model_name, timeout, retries)
+        self.host = host or config.ollama_host
+        self.headers = {
+            "ngrok-skip-browser-warning": "true",
+            "User-Agent": "AI-Life-Coach"
+        }
+
+    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
+        def _make_request():
+            url = f"{self.host}/api/generate"
+            payload = {
+                "model": self.model_name,
+                "prompt": prompt,
+                "stream": stream,
+                "options": {
+                    "num_predict": max_tokens
+                }
+            }
+
+            response = requests.post(
+                url,
+                json=payload,
+                headers=self.headers,
+                timeout=self.timeout,
+                stream=stream
+            )
+
+            if response.status_code != 200:
+                raise Exception(f"Ollama API error: {response.text}")
+
+            if stream:
+                def stream_response():
+                    for line in response.iter_lines():
+                        if line:
+                            try:
+                                data = json.loads(line.decode('utf-8'))
+                                if 'response' in data:
+                                    yield data['response']
+                            except:
+                                continue
+                return stream_response()
+            else:
+                return response.json()["response"]
+
+        return self._retry_request(_make_request)
+
+
+class HuggingFaceProvider(LLMProvider):
+    def __init__(self, model_name: str, timeout: int = 30, retries: int = 3):
+        super().__init__(model_name, timeout, retries)
+        self.client = openai.OpenAI(
             base_url=config.hf_api_url,
             api_key=config.hf_token
         )
 
-    def generate(self, prompt, max_tokens=8192, stream=True):
-        if self.provider == "ollama":
-            return self._generate_ollama(prompt, max_tokens, stream)
-        elif self.provider == "huggingface":
-            return self._generate_hf(prompt, max_tokens, stream)
-        else:
-            raise ValueError(f"Unsupported provider: {self.provider}")
-
-    def _generate_ollama(self, prompt, max_tokens, stream):
-        url = f"{config.ollama_host}/api/generate"
-        payload = {
-            "model": self.model_name,
-            "prompt": prompt,
-            "stream": stream
-        }
+    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
+        def _make_request():
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=max_tokens,
+                stream=stream
+            )
+
+            if stream:
+                def stream_response():
+                    for chunk in response:
+                        content = chunk.choices[0].delta.content
+                        if content:
+                            yield content
+                return stream_response()
+            else:
+                return response.choices[0].message.content
+
+        return self._retry_request(_make_request)
 
-        try:
-            with requests.post(url, json=payload, stream=stream) as response:
-                if response.status_code != 200:
-                    raise Exception(f"Ollama API error: {response.text}")
-
-                if stream:
-                    return (chunk.decode("utf-8") for chunk in response.iter_content())
-                else:
-                    return response.json()["response"]
-        except Exception as e:
-            raise Exception(f"Ollama request failed: {e}")
-
-    def _generate_hf(self, prompt, max_tokens, stream):
-        try:
-            response = self.hf_client.chat.completions.create(
+
+class OpenAIProvider(LLMProvider):
+    def __init__(self, model_name: str, api_key: str = None, timeout: int = 30, retries: int = 3):
+        super().__init__(model_name, timeout, retries)
+        self.client = openai.OpenAI(api_key=api_key or config.openai_api_key)
+
+    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
+        def _make_request():
+            response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=[{"role": "user", "content": prompt}],
                 max_tokens=max_tokens,
                 stream=stream
             )
+
             if stream:
-                return (chunk.choices[0].delta.content or "" for chunk in response)
+                def stream_response():
+                    for chunk in response:
+                        content = chunk.choices[0].delta.content
+                        if content:
+                            yield content
+                return stream_response()
             else:
-                return response.choices[0].text
-        except Exception as e:
-            raise Exception(f"Hugging Face API error: {e}")
+                return response.choices[0].message.content
+
+        return self._retry_request(_make_request)
+
+
+class LLMClient:
+    PROVIDER_MAP = {
+        "ollama": OllamaProvider,
+        "huggingface": HuggingFaceProvider,
+        "openai": OpenAIProvider
+    }
+
+    def __init__(self, provider: str = "ollama", model_name: str = None, **provider_kwargs):
+        self.provider_name = provider.lower()
+        self.model_name = model_name or self._get_default_model()
+
+        if self.provider_name not in self.PROVIDER_MAP:
+            raise ValueError(f"Unsupported provider: {provider}")
+
+        provider_class = self.PROVIDER_MAP[self.provider_name]
+        self.provider = provider_class(self.model_name, **provider_kwargs)
+
+    def _get_default_model(self) -> str:
+        """Get default model based on provider"""
+        defaults = {
+            "ollama": config.local_model_name,
+            "huggingface": "meta-llama/Meta-Llama-3-8B-Instruct",
+            "openai": "gpt-3.5-turbo"
+        }
+        return defaults.get(self.provider_name, "mistral")
+
+    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
+        """Unified generate method that delegates to provider"""
+        return self.provider.generate(prompt, max_tokens, stream)
+
+    @classmethod
+    def get_available_providers(cls) -> list:
+        """Return list of supported providers"""
+        return list(cls.PROVIDER_MAP.keys())
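
For reference, a minimal usage sketch of the refactored interface (not part of the commit). It assumes core.llm is importable from the project root and that utils.config supplies the settings referenced in the diff (ollama_host, hf_api_url, hf_token, openai_api_key, local_model_name); the prompts and keyword values below are illustrative only.

from core.llm import LLMClient

# Providers registered in LLMClient.PROVIDER_MAP.
print(LLMClient.get_available_providers())  # ['ollama', 'huggingface', 'openai']

# Non-streaming call against a local Ollama instance; timeout/retries are
# forwarded to OllamaProvider through **provider_kwargs.
client = LLMClient(provider="ollama", model_name="mistral", timeout=60, retries=2)
print(client.generate("Give me one tip for staying focused.", max_tokens=200))

# Streaming call: generate() returns a generator of text chunks when stream=True.
for chunk in client.generate("Summarize my day in one sentence.", max_tokens=200, stream=True):
    print(chunk, end="", flush=True)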