-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from hrszpuk/dev
Minimum Viable Product - Complete
- Loading branch information
Showing
31 changed files
with
631 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
This is the core module for parrot. This module contains all the "business" logic used in the application. | ||
""" | ||
from . import manager | ||
from . import interface | ||
from .controller import Controller | ||
from .model_explorer import ModelExplorer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import sqlite3 | ||
from datetime import datetime | ||
|
||
|
||
class ChatManager: | ||
"""Manages chat sessions and message history.""" | ||
|
||
def __init__(self, db_path='chat_history.db'): | ||
self.db_path = db_path | ||
self._connect() | ||
|
||
def _connect(self): | ||
"""Connect to the SQLite database and create necessary tables.""" | ||
self.conn = sqlite3.connect(self.db_path) | ||
self.cursor = self.conn.cursor() | ||
|
||
self.cursor.execute(""" | ||
CREATE TABLE IF NOT EXISTS chats ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
chat_name TEXT NOT NULL, | ||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, | ||
model_name TEXT NOT NULL | ||
) | ||
""") | ||
|
||
self.conn.commit() | ||
|
||
def create_chat_session(self, chat_name, model_name) -> int: | ||
"""Creates a new chat session and the corresponding message table.""" | ||
timestamp = datetime.now() | ||
|
||
self.cursor.execute(""" | ||
INSERT INTO chats (chat_name, timestamp, model_name) | ||
VALUES (?, ?, ?) | ||
""", (chat_name, timestamp, model_name)) | ||
|
||
chat_id = self.cursor.lastrowid | ||
table_name = f"chat_{chat_id}_messages" | ||
|
||
self.cursor.execute(f""" | ||
CREATE TABLE IF NOT EXISTS {table_name} ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
chat_id INTEGER NOT NULL, | ||
sender TEXT NOT NULL, | ||
message TEXT NOT NULL, | ||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, | ||
FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE | ||
) | ||
""") | ||
|
||
self.conn.commit() | ||
|
||
return chat_id | ||
|
||
def get_chat_sessions(self): | ||
"""Fetches a list of all chat sessions.""" | ||
self.cursor.execute("SELECT id, chat_name, timestamp, model_name FROM chats") | ||
return self.cursor.fetchall() | ||
|
||
def add_message(self, chat_id, sender, message): | ||
"""Adds a message to the chat session.""" | ||
table_name = f"chat_{chat_id}_messages" | ||
timestamp = datetime.now() | ||
|
||
self.cursor.execute(f""" | ||
INSERT INTO {table_name} (chat_id, sender, message, timestamp) | ||
VALUES (?, ?, ?, ?) | ||
""", (chat_id, sender, message, timestamp)) | ||
|
||
self.conn.commit() | ||
|
||
def edit_message(self, chat_id, message_id, new_message): | ||
"""Edits an existing message in the chat session.""" | ||
table_name = f"chat_{chat_id}_messages" | ||
|
||
self.cursor.execute(f""" | ||
UPDATE {table_name} | ||
SET message = ?, timestamp = ? | ||
WHERE id = ? | ||
""", (new_message, datetime.now(), message_id)) | ||
|
||
self.conn.commit() | ||
|
||
def delete_chat_session(self, chat_id): | ||
"""Deletes the chat session and its associated messages.""" | ||
self.cursor.execute("DELETE FROM chats WHERE id = ?", (chat_id,)) | ||
|
||
table_name = f"chat_{chat_id}_messages" | ||
self.cursor.execute(f"DROP TABLE IF EXISTS {table_name}") | ||
|
||
self.conn.commit() | ||
|
||
def get_messages(self, chat_id): | ||
"""Fetches all messages for a specific chat session.""" | ||
table_name = f"chat_{chat_id}_messages" | ||
self.cursor.execute(f"SELECT id, sender, message, timestamp FROM {table_name}") | ||
return self.cursor.fetchall() | ||
|
||
def close(self): | ||
"""Closes the connection to the database.""" | ||
self.conn.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from .chat_manager import ChatManager | ||
|
||
|
||
class Controller: | ||
"""Handles business logic for the application and acts as the interface between the GUI and the ChatManager.""" | ||
|
||
def __init__(self): | ||
self.chat_manager = ChatManager() | ||
self.current_chat_id = None | ||
|
||
def start_new_chat(self, chat_name, model_name="default_model"): | ||
self.current_chat_id = self.chat_manager.create_chat_session(chat_name, model_name) | ||
return self.current_chat_id | ||
|
||
def set_current_chat(self, chat_id): | ||
self.current_chat_id = chat_id | ||
|
||
def get_chat_sessions(self): | ||
return self.chat_manager.get_chat_sessions() | ||
|
||
def get_current_chat_history(self): | ||
if self.current_chat_id is None: | ||
return [] | ||
|
||
return self.chat_manager.get_messages(self.current_chat_id) | ||
|
||
def add_message_to_current_chat(self, sender, message): | ||
if self.current_chat_id is None: | ||
raise ValueError("No current chat session set.") | ||
|
||
self.chat_manager.add_message(self.current_chat_id, sender, message) | ||
|
||
def close(self): | ||
self.chat_manager.close() |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from huggingface_hub import HfApi | ||
|
||
|
||
class ModelExplorer: | ||
"""Handles both exploring popular models and searching for specific models on Hugging Face.""" | ||
|
||
def __init__(self): | ||
self.api = HfApi() | ||
|
||
def get_popular_models(self, limit=10): | ||
"""Returns a list of popular models from Hugging Face.""" | ||
popular_models = self.api.list_models(sort='downloads', limit=limit) | ||
model_names = [model.id for model in popular_models] | ||
return model_names | ||
|
||
def search_model(self, query: str, limit=10): | ||
"""Searches Hugging Face models based on a query.""" | ||
search_results = self.api.list_models(search=query, limit=limit) | ||
model_names = [model.id for model in search_results] | ||
return model_names |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
import shutil | ||
from transformers import AutoModel, AutoTokenizer | ||
|
||
|
||
class ModelManager: | ||
"""Manages installing new open-source models, keeping track of installed models, and loading models onto the CPU/GPU.""" | ||
|
||
def __init__(self, models_dir='models'): | ||
self.models_dir = models_dir | ||
os.makedirs(self.models_dir, exist_ok=True) # Ensure models directory exists | ||
|
||
def download_model(self, model_name: str): | ||
"""Downloads the specified model from Hugging Face and saves it locally.""" | ||
model_path = os.path.join(self.models_dir, model_name) | ||
if not os.path.exists(model_path): | ||
print(f"Downloading model: {model_name}") | ||
AutoModel.from_pretrained(model_name).save_pretrained(model_path) | ||
AutoTokenizer.from_pretrained(model_name).save_pretrained(model_path) | ||
else: | ||
print(f"Model {model_name} is already downloaded.") | ||
|
||
def load_model(self, model_name: str): | ||
"""Loads the specified model from disk to CPU/GPU.""" | ||
model_path = os.path.join(self.models_dir, model_name) | ||
if os.path.exists(model_path): | ||
print(f"Loading model: {model_name}") | ||
model = AutoModel.from_pretrained(model_path) | ||
tokenizer = AutoTokenizer.from_pretrained(model_path) | ||
return model, tokenizer | ||
else: | ||
raise FileNotFoundError(f"Model {model_name} not found. Please download it first.") | ||
|
||
def delete_model(self, model_name: str): | ||
"""Deletes the specified model from disk.""" | ||
model_path = os.path.join(self.models_dir, model_name) | ||
if os.path.exists(model_path): | ||
shutil.rmtree(model_path) | ||
print(f"Deleted model: {model_name}") | ||
else: | ||
print(f"Model {model_name} not found.") | ||
|
||
def list_models(self): | ||
"""Lists all installed models.""" | ||
return os.listdir(self.models_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,95 @@ | ||
from datetime import datetime | ||
|
||
from PySide6.QtWidgets import * | ||
|
||
from parrot.gui.widgets.chat_window import ChatWindow | ||
from parrot.core import Controller | ||
from parrot.gui.widgets.chat import ChatWindow | ||
from parrot.gui.widgets.my_copilots import ModelMenu | ||
from parrot.gui.widgets.settings import SettingsMenu | ||
from parrot.gui.widgets.sidebar import Sidebar | ||
|
||
|
||
class MainWindow(QMainWindow): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
self.controller = Controller() | ||
|
||
self.setWindowTitle("Parrot") | ||
self.setGeometry(600, 100, 1600, 900) # Position and size (x, y, width, height) | ||
self.setGeometry(600, 100, 1400, 800) | ||
self.show() | ||
self.setup_ui() | ||
|
||
def setup_ui(self): | ||
main_layout = QHBoxLayout() | ||
self.main_layout = QHBoxLayout() | ||
|
||
self.sidebar_widget = Sidebar() | ||
self.main_layout.addWidget(self.sidebar_widget) | ||
|
||
self.sidebar_widget.chat_list.new_chat_button_pressed.connect(self.create_new_chat) | ||
self.sidebar_widget.bottom.settings_clicked.connect(self.open_settings_menu) | ||
self.sidebar_widget.bottom.ai_model_manager_clicked.connect(self.open_copilots_menu) | ||
|
||
self.chat_window = ChatWindow() | ||
self.main_layout.addWidget(self.chat_window) | ||
|
||
self.settings_menu = SettingsMenu() | ||
self.settings_menu.settings_closed.connect(self.show_main_content) | ||
self.settings_menu.setVisible(False) | ||
self.main_layout.addWidget(self.settings_menu) | ||
|
||
chat_window = ChatWindow() | ||
chat_window.add_message("Hello?", is_user=False) | ||
chat_window.add_message("Hello?", is_user=True) | ||
chat_window.add_message("Hello?", is_user=False) | ||
main_layout.addWidget(chat_window) | ||
self.copilots_menu = ModelMenu() | ||
self.copilots_menu.my_copilots_closed.connect(self.show_main_content) | ||
self.copilots_menu.setVisible(False) | ||
self.main_layout.addWidget(self.copilots_menu) | ||
|
||
widget = QWidget() | ||
widget.setLayout(main_layout) | ||
widget.setLayout(self.main_layout) | ||
|
||
self.setCentralWidget(widget) | ||
|
||
def create_new_chat(self): | ||
chat_name = f"New Chat ({len(self.controller.get_chat_sessions())})" | ||
self.sidebar_widget.chat_list.add_new_chat(chat_name) | ||
|
||
self.controller.start_new_chat(chat_name) | ||
|
||
self.chat_window.chat_top_bar.title = chat_name | ||
self.chat_window.message_list.clear_messages() | ||
self.chat_window.chat_box.text_input.setText("") | ||
|
||
def send_message(self): | ||
message = self.chat_window.chat_box.get_text() | ||
|
||
self.controller.add_message_to_current_chat("user", message) | ||
self.chat_window.add_message(message, is_user=True, loading=False) | ||
self.chat_window.chat_box.clear_text() | ||
|
||
def load_chat_history(self, chat_id): | ||
self.controller.set_current_chat(chat_id) | ||
|
||
chat_history = self.controller.get_current_chat_history() | ||
|
||
self.chat_window.message_list.clear_messages() | ||
for message in chat_history: | ||
self.chat_window.message_list.add_message(message[1], message[2]) # (sender, message) | ||
|
||
def closeEvent(self, event): | ||
self.controller.close() | ||
event.accept() | ||
|
||
def open_settings_menu(self): | ||
self.chat_window.setVisible(False) | ||
self.sidebar_widget.setVisible(False) | ||
|
||
self.settings_menu.setVisible(True) | ||
|
||
def open_copilots_menu(self): | ||
self.chat_window.setVisible(False) | ||
self.sidebar_widget.setVisible(False) | ||
|
||
self.copilots_menu.setVisible(True) | ||
|
||
def show_main_content(self): | ||
self.settings_menu.setVisible(False) | ||
self.copilots_menu.setVisible(False) | ||
self.chat_window.setVisible(True) | ||
self.sidebar_widget.setVisible(True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .chat_box import ChatBox | ||
from .chat_message import ChatMessage | ||
from .chat_message_list import ChatMessageList | ||
from .chat_window import ChatWindow |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.