Shortcuts

Source code for minedojo.data.reddit_dataset

from __future__ import annotations

import json
import os

import praw

from .download import download as dl
from .download import get_fn


[docs]class RedditDataset: """ Class for MineDojo Reddit Database API. We follow PyTorch Dataset format but without actually inheriting from PyTorch dataset to keep the framework general. See https://praw.readthedocs.io/en/stable/getting_started/quick_start.html for setting up ``client_id``, ``cliend_secret`` and ``user_agent``. Args: download: If ``True`` and there is no existing cache directory, the data will be downloaded automatically. download_dir: Directory path where the downloaded data will be saved. Default: ``~/.minedojo/``. client_id: The client ID to access Reddit’s API as a script application. client_secret: The client secret to access Reddit’s API as a script application. user_agent: A unique identifier that helps Reddit determine the source of network requests. max_comments: Maximum number of comments to load. Examples: >>> from minedojo.data import RedditDataset >>> reddit_dataset = RedditDataset(client_id={your_client_id}, client_secret={your_client_secret}, user_agent={your_user_agent}) >>> print(reddit_dataset[0].keys()) dict_keys(['id', 'title', 'link', 'score', 'num_comments', 'created_utc', 'type', 'content', 'comments']) """ def __init__( self, *, download: bool = True, download_dir: None | str = None, client_id: str = None, client_secret: str = None, user_agent: str = None, max_comments: int = 100, ): if download_dir is None: download_dir = os.path.join(os.path.expanduser("~"), ".minedojo") if download: self.root = dl("reddit", download_dir) else: self.root, _, url = get_fn("reddit", download_dir) assert os.path.exists(self.root), ( f"Reddit data file {self.root} does not exist. " "Please set download=True or you can manually " f"download it from {url}." ) with open(self.root, "r") as f: self.data = json.load(f) if client_id is None or client_secret is None or user_agent is None: raise RedditAPIKeyNotSpecifiedError self.api = praw.Reddit( client_id=client_id, client_secret=client_secret, user_agent=user_agent, check_for_async=False, ) self.max_comments = max_comments
[docs] def get_metadata(self, post_id: str, post_type: str) -> dict: """Get post metadata using PRAW. Args: post_id: The unique, base36 ID of a Reddit post. post_type: The type of the post, either "image", "text", "video" or "link". Return: A dictionary containing the metadata of the post. - id(``str``) - The unique, base36 Reddit post ID. - title(``str``) - The title of the Reddit post. - link(``str``) - The url of the Reddit post. - score(``int``) - The score of the Reddit post. - num_comments(``int``) - The number of comments under the Reddit post. Does not account for deleted comments. - created_utc(``int``) - The date and time the Reddit post was created, in UTC format. - type(``str``) - The type of the post, either "image", "text", "video" or "link". - content(``str``) - If text type post, text in post body. Otherwise, the media source url or website link. - comments(``list[dict]``) - id(``str``) - The unique base36 comment ID. - parent_id(``str``) - The ID of the comment's parent in the nested comment tree. - content(``str``) - The text in comment body. """ post = self.api.submission(id=post_id) metadata = { "id": post.id, "title": post.title, "link": f"https://www.reddit.com/r/Minecraft/comments/{post.id}", "score": post.score, "num_comments": post.num_comments, "created_utc": post.created_utc, "type": post_type, "content": post.selftext if post_type == "text" else post.url, "comments": self.get_comments(post), } return metadata
[docs] def get_comments(self, post: praw.models.Submission) -> list[dict]: comments = [] comment_queue = post.comments[:] while len(comments) < self.max_comments and comment_queue: comment = comment_queue.pop(0) if isinstance(comment, praw.models.MoreComments): comment_queue.extend(comment.comments()) continue comments.append( { "id": comment.id, "parent_id": comment.parent_id[3:], "content": comment.body, } ) return comments
def __len__(self): return len(self.data) def __getitem__(self, idx): return self.get_metadata( post_id=self.data[idx]["id"], post_type=self.data[idx]["type"] )
[docs]class RedditAPIKeyNotSpecifiedError(Exception): def __init__(self): self.message = ( 'You need to specify "client_id", "client_secret" and "user_agent" for Reddit API. ' "You can refer to https://praw.readthedocs.io/en/stable/getting_started/quick_start.html " "for the instructions of obtaining Reddit API keys." ) super().__init__(self.message)