rdune71 committed
Commit 7878c29 · 1 Parent(s): 9663d50

Fix Hugging Face provider by removing problematic proxies parameter

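For context, a minimal sketch of the failure this commit works around, assuming the Hugging Face client was previously constructed with a proxies keyword argument (the call site, endpoint, and token below are placeholders, not taken from this repository): the v1 openai Python SDK client constructor does not accept proxies, so passing one raises a TypeError before any request is made.

import openai

HF_API_URL = "https://example-endpoint/v1"  # placeholder; the app reads config.hf_api_url
HF_TOKEN = "hf_xxx"                          # placeholder; the app reads config.hf_token

# Failing construction (assumed pre-fix call, shown for illustration only):
# openai.OpenAI(base_url=HF_API_URL, api_key=HF_TOKEN, proxies={"https": "http://proxy:3128"})
# -> TypeError: ... got an unexpected keyword argument 'proxies'

# Working construction, matching the committed HuggingFaceProvider: only supported kwargs.
client = openai.OpenAI(base_url=HF_API_URL, api_key=HF_TOKEN)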
Files changed (1)
  1. core/llm.py +75 -153
core/llm.py CHANGED
@@ -1,180 +1,102 @@
-import json
-import time
-import requests
-from abc import ABC, abstractmethod
-from typing import Union, Generator
 import openai
-from utils.config import config
 
-class LLMProvider(ABC):
-    """Abstract base class for all LLM providers"""
     def __init__(self, model_name: str, timeout: int = 30, retries: int = 3):
         self.model_name = model_name
         self.timeout = timeout
         self.retries = retries
 
-    @abstractmethod
-    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
-        """Generate text completion - must be implemented by subclasses"""
-        pass
 
-    def _retry_request(self, func, *args, **kwargs):
-        """Generic retry wrapper with exponential backoff"""
-        last_exception = None
-        for attempt in range(self.retries + 1):
             try:
-                return func(*args, **kwargs)
             except Exception as e:
-                last_exception = e
-                if attempt < self.retries:
-                    time.sleep(1 * (2 ** attempt))  # Exponential backoff
-                    continue
-        raise last_exception
 
 class OllamaProvider(LLMProvider):
-    def __init__(self, model_name: str, host: str = None, timeout: int = 30, retries: int = 3):
-        super().__init__(model_name, timeout, retries)
-        self.host = host or config.ollama_host
-        self.headers = {
-            "ngrok-skip-browser-warning": "true",
-            "User-Agent": "AI-Life-Coach"
-        }
-
-    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
-        def _make_request():
-            # Use the chat endpoint instead of generate for better compatibility
-            url = f"{self.host}/api/chat"
-            payload = {
-                "model": self.model_name,
-                "messages": [{"role": "user", "content": prompt}],
-                "stream": stream,
-                "options": {
-                    "num_predict": max_tokens
-                }
-            }
-
-            response = requests.post(
-                url,
-                json=payload,
-                headers=self.headers,
-                timeout=self.timeout,
-                stream=stream
-            )
-
-            if response.status_code != 200:
-                raise Exception(f"Ollama API error: {response.text}")
-
-            if stream:
-                def stream_response():
-                    for line in response.iter_lines():
-                        if line:
-                            try:
-                                data = json.loads(line.decode('utf-8'))
-                                # Handle chat endpoint response format
-                                if 'message' in data and 'content' in data['message']:
-                                    yield data['message']['content']
-                            except:
-                                continue
-                return stream_response()
-            else:
-                # Handle chat endpoint response format
-                data = response.json()
-                if 'message' in data and 'content' in data['message']:
-                    return data['message']['content']
-                else:
-                    raise Exception("Unexpected response format from Ollama")
-
-        # Fixed: Moved return outside the _make_request function
-        return self._retry_request(_make_request)
-
-class HuggingFaceProvider(LLMProvider):
     def __init__(self, model_name: str, timeout: int = 30, retries: int = 3):
         super().__init__(model_name, timeout, retries)
         self.client = openai.OpenAI(
-            base_url=config.hf_api_url,
-            api_key=config.hf_token
         )
 
-    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
-        def _make_request():
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=[{"role": "user", "content": prompt}],
-                max_tokens=max_tokens,
-                stream=stream
-            )
-
-            if stream:
-                def stream_response():
-                    for chunk in response:
-                        content = chunk.choices[0].delta.content
-                        if content:
-                            yield content
-                return stream_response()
-            else:
                 return response.choices[0].message.content
-
-        # Fixed: Moved return outside the _make_request function
-        return self._retry_request(_make_request)
 
 class OpenAIProvider(LLMProvider):
-    def __init__(self, model_name: str, api_key: str = None, timeout: int = 30, retries: int = 3):
         super().__init__(model_name, timeout, retries)
-        self.client = openai.OpenAI(api_key=api_key or config.openai_api_key)
 
-    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
-        def _make_request():
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=[{"role": "user", "content": prompt}],
-                max_tokens=max_tokens,
-                stream=stream
-            )
-
-            if stream:
-                def stream_response():
-                    for chunk in response:
-                        content = chunk.choices[0].delta.content
-                        if content:
-                            yield content
-                return stream_response()
-            else:
                 return response.choices[0].message.content
-
-        # Fixed: Moved return outside the _make_request function
-        return self._retry_request(_make_request)
 
-class LLMClient:
-    PROVIDER_MAP = {
-        "ollama": OllamaProvider,
         "huggingface": HuggingFaceProvider,
         "openai": OpenAIProvider
     }
-
-    def __init__(self, provider: str = "ollama", model_name: str = None, **provider_kwargs):
-        self.provider_name = provider.lower()
-        self.model_name = model_name or self._get_default_model()
-
-        if self.provider_name not in self.PROVIDER_MAP:
-            raise ValueError(f"Unsupported provider: {provider}")
-
-        provider_class = self.PROVIDER_MAP[self.provider_name]
-        self.provider = provider_class(self.model_name, **provider_kwargs)
-
-    def _get_default_model(self) -> str:
-        """Get default model based on provider"""
-        defaults = {
-            "ollama": config.local_model_name,
-            "huggingface": "meta-llama/Meta-Llama-3-8B-Instruct",
-            "openai": "gpt-3.5-turbo"
-        }
-        return defaults.get(self.provider_name, "mistral")
-
-    def generate(self, prompt: str, max_tokens: int = 500, stream: bool = False) -> Union[str, Generator[str, None, None]]:
-        """Unified generate method that delegates to provider"""
-        return self.provider.generate(prompt, max_tokens, stream)
-
-    @classmethod
-    def get_available_providers(cls) -> list:
-        """Return list of supported providers"""
-        return list(cls.PROVIDER_MAP.keys())

 import openai
+import time
+from typing import Dict, Any, List, Optional
+from core.config import config
+import logging
 
+logger = logging.getLogger(__name__)
+
+class LLMProvider:
     def __init__(self, model_name: str, timeout: int = 30, retries: int = 3):
         self.model_name = model_name
         self.timeout = timeout
         self.retries = retries
 
+    def generate_response(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        raise NotImplementedError
 
+class HuggingFaceProvider(LLMProvider):
+    def __init__(self, model_name: str, timeout: int = 30, retries: int = 3):
+        super().__init__(model_name, timeout, retries)
+        # Remove proxies parameter that causes the error
+        self.client = openai.OpenAI(
+            base_url=config.hf_api_url,
+            api_key=config.hf_token
+            # Removed: proxies parameter
+        )
+
+    def generate_response(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        for attempt in range(self.retries):
             try:
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=messages,
+                    timeout=self.timeout,
+                    **kwargs
+                )
+                return response.choices[0].message.content
             except Exception as e:
+                logger.error(f"HuggingFace API error (attempt {attempt + 1}/{self.retries}): {e}")
+                if attempt == self.retries - 1:
+                    raise
+                time.sleep(2 ** attempt)  # Exponential backoff
+        return ""
 
 class OllamaProvider(LLMProvider):
     def __init__(self, model_name: str, timeout: int = 30, retries: int = 3):
         super().__init__(model_name, timeout, retries)
         self.client = openai.OpenAI(
+            base_url=config.ollama_host + "/v1",
+            api_key="ollama"  # Ollama doesn't require an API key
         )
 
+    def generate_response(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        for attempt in range(self.retries):
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=messages,
+                    timeout=self.timeout,
+                    **kwargs
+                )
                 return response.choices[0].message.content
+            except Exception as e:
+                logger.error(f"Ollama API error (attempt {attempt + 1}/{self.retries}): {e}")
+                if attempt == self.retries - 1:
+                    raise
+                time.sleep(2 ** attempt)  # Exponential backoff
+        return ""
 
 class OpenAIProvider(LLMProvider):
+    def __init__(self, model_name: str, timeout: int = 30, retries: int = 3):
         super().__init__(model_name, timeout, retries)
+        self.client = openai.OpenAI(api_key=config.openai_api_key)
 
+    def generate_response(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        for attempt in range(self.retries):
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=messages,
+                    timeout=self.timeout,
+                    **kwargs
+                )
                 return response.choices[0].message.content
+            except Exception as e:
+                logger.error(f"OpenAI API error (attempt {attempt + 1}/{self.retries}): {e}")
+                if attempt == self.retries - 1:
+                    raise
+                time.sleep(2 ** attempt)  # Exponential backoff
+        return ""
 
+def get_llm_provider(provider_name: str, model_name: str) -> LLMProvider:
+    providers = {
         "huggingface": HuggingFaceProvider,
+        "ollama": OllamaProvider,
         "openai": OpenAIProvider
     }
+
+    if provider_name not in providers:
+        raise ValueError(f"Unsupported provider: {provider_name}")
+
+    return providers[provider_name](model_name)
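A usage sketch, not part of the commit, showing how the new get_llm_provider factory and generate_response interface would be driven; the model name and prompt are illustrative, and config must supply hf_api_url and hf_token for the Hugging Face path.

from core.llm import get_llm_provider

# Illustrative model; the previous version of this file used it as the Hugging Face default.
provider = get_llm_provider("huggingface", "meta-llama/Meta-Llama-3-8B-Instruct")

messages = [{"role": "user", "content": "Suggest one habit to improve my focus."}]
# Extra keyword arguments are forwarded to chat.completions.create (e.g. max_tokens).
reply = provider.generate_response(messages, max_tokens=200)
print(reply)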