Spaces:
Sleeping
Sleeping
| import os | |
| from pathlib import Path | |
| import PIL | |
| from llmlib.base_llm import LLM, Message | |
| import pytest | |
| def assert_model_knows_capital_of_france(model: LLM) -> None: | |
| response: str = model.complete_msgs2( | |
| msgs=[Message(role="user", msg="What is the capital of France?")] | |
| ) | |
| assert "paris" in response.lower() | |
| def assert_model_can_answer_batch_of_text_prompts(model: LLM) -> None: | |
| prompts = [ | |
| "What is the capital of France?", | |
| "What continent is south of Europe?", | |
| "What are the two tallest mountains in the world?", | |
| ] | |
| batch = [[Message.from_prompt(prompt)] for prompt in prompts] | |
| responses = model.complete_batch(batch=batch) | |
| assert len(responses) == 3 | |
| assert "paris" in responses[0].lower() | |
| assert "africa" in responses[1].lower() | |
| assert "everest" in responses[2].lower() | |
| def assert_model_can_answer_batch_of_img_prompts(model: LLM) -> None: | |
| batch = [ | |
| [pyramid_message()], | |
| [forest_message()], | |
| [fish_message()], | |
| ] | |
| responses = model.complete_batch(batch=batch) | |
| assert len(responses) == 3 | |
| assert "pyramid" in responses[0].lower() | |
| assert "forest" in responses[1].lower() | |
| assert "fish" in responses[2].lower() | |
| def assert_model_rejects_unsupported_batches(model: LLM) -> None: | |
| mixed_textonly_and_img_batch = [ | |
| [Message.from_prompt("What is the capital of France?")], | |
| [pyramid_message()], | |
| ] | |
| err_msg = "Batch must contain an image in every entry or none at all." | |
| with pytest.raises(ValueError, match=err_msg): | |
| model.complete_batch(mixed_textonly_and_img_batch) | |
| def assert_model_recognizes_pyramid_in_image(model: LLM): | |
| msg = pyramid_message() | |
| answer: str = model.complete_msgs2(msgs=[msg]) | |
| assert "pyramid" in answer.lower() | |
| def assert_model_recognizes_afd_in_video(model: LLM): | |
| video_path = file_for_test("video.mp4") | |
| question = "Describe the video in english" | |
| answer: str = model.video_prompt(video_path, question) | |
| assert "alternative für deutschland" in answer.lower(), answer | |
| def get_mona_lisa_completion(model: LLM) -> str: | |
| msg: Message = mona_lisa_message() | |
| answer: str = model.complete_msgs2(msgs=[msg]) | |
| return answer | |
| def mona_lisa_message() -> Message: | |
| _, img = mona_lisa_filename_and_img() | |
| prompt = "What is in the image?" | |
| msg = Message(role="user", msg=prompt, img=img, img_name="") | |
| return msg | |
| def pyramid_message() -> Message: | |
| img_name = "pyramid.jpg" | |
| img = get_test_img(img_name) | |
| msg = Message(role="user", msg="What is in the image?", img=img, img_name="") | |
| return msg | |
| def forest_message() -> Message: | |
| img_name = "forest.jpg" | |
| img = get_test_img(img_name) | |
| msg = Message( | |
| role="user", msg="Describe what you see in the picture.", img=img, img_name="" | |
| ) | |
| return msg | |
| def fish_message() -> Message: | |
| img_name = "fish.jpg" | |
| img = get_test_img(img_name) | |
| msg = Message( | |
| role="user", | |
| msg="What animal is depicted and where does it live?", | |
| img=img, | |
| img_name="", | |
| ) | |
| return msg | |
| def mona_lisa_filename_and_img() -> tuple[str, PIL.Image.Image]: | |
| img_name = "mona-lisa.png" | |
| img = get_test_img(img_name) | |
| return img_name, img | |
| def get_test_img(name: str) -> PIL.Image.Image: | |
| path = file_for_test(name) | |
| return PIL.Image.open(path) | |
| def file_for_test(name: str) -> Path: | |
| return Path(__file__).parent.parent / "test-files" / name | |
| def is_ci() -> bool: | |
| is_ci_str: str = os.environ.get("CI", "false").lower() | |
| return is_ci_str != "false" | |